def test_save_best_state_callback(save_path):
    """Smoke test: SCVI training completes with a SaveBestState callback attached."""
    n_latent = 5
    data = synthetic_iid()
    SCVI.setup_anndata(data, batch_key="batch", labels_key="labels")
    vae = SCVI(data, n_latent=n_latent)
    vae.train(
        3,
        check_val_every_n_epoch=1,
        train_size=0.5,
        callbacks=[SaveBestState(verbose=True)],
    )
def test_scvi_library_size_update(save_path):
    """Check that per-batch library-size priors grow correctly when a saved model
    is reloaded as a query model on data with new batch categories."""
    n_latent = 5
    adata1 = synthetic_iid()
    SCVI.setup_anndata(adata1, batch_key="batch", labels_key="labels")
    # use_observed_lib_size=False forces the module to keep per-batch
    # library_log_means / library_log_vars buffers (one column per batch).
    model = SCVI(adata1, n_latent=n_latent, use_observed_lib_size=False)
    # two batches in adata1 -> shape (1, 2), and both entries must be populated
    assert (getattr(model.module, "library_log_means", None) is not None
            and model.module.library_log_means.shape == (1, 2)
            and model.module.library_log_means.count_nonzero().item() == 2)
    assert getattr(
        model.module,
        "library_log_vars",
        None) is not None and model.module.library_log_vars.shape == (1, 2)
    model.train(1, check_val_every_n_epoch=1)
    dir_path = os.path.join(save_path, "saved_model/")
    model.save(dir_path, overwrite=True)
    # also test subset var option
    adata2 = synthetic_iid(n_genes=110)
    # rename categories so the query data introduces two *new* batches
    adata2.obs["batch"] = adata2.obs.batch.cat.rename_categories(
        ["batch_2", "batch_3"])
    model2 = SCVI.load_query_data(adata2,
                                  dir_path,
                                  inplace_subset_query_vars=True)
    # 2 original + 2 query batches -> shape (1, 4); the first two columns must
    # be carried over unchanged from the reference model
    assert (getattr(model2.module, "library_log_means", None) is not None
            and model2.module.library_log_means.shape == (1, 4)
            and model2.module.library_log_means[:, :2].equal(
                model.module.library_log_means)
            and model2.module.library_log_means.count_nonzero().item() == 4)
    assert (getattr(model2.module, "library_log_vars", None) is not None
            and model2.module.library_log_vars.shape == (1, 4)
            and model2.module.library_log_vars[:, :2].equal(
                model.module.library_log_vars))
def test_multiple_covariates_scvi(save_path):
    """Train SCVI, SCANVI and TOTALVI with extra continuous and categorical covariates."""
    data = synthetic_iid()
    n_obs = data.shape[0]
    # two continuous then two categorical columns (same RNG call order as before)
    for key in ("cont1", "cont2"):
        data.obs[key] = np.random.normal(size=(n_obs, ))
    for key in ("cat1", "cat2"):
        data.obs[key] = np.random.randint(0, 5, size=(n_obs, ))
    SCVI.setup_anndata(
        data,
        batch_key="batch",
        labels_key="labels",
        continuous_covariate_keys=["cont1", "cont2"],
        categorical_covariate_keys=["cat1", "cat2"],
    )
    m = SCVI(data)
    m.train(1)
    m = SCANVI(data, unlabeled_category="Unknown")
    m.train(1)
    TOTALVI.setup_anndata(
        data,
        batch_key="batch",
        protein_expression_obsm_key="protein_expression",
        protein_names_uns_key="protein_names",
        continuous_covariate_keys=["cont1", "cont2"],
        categorical_covariate_keys=["cat1", "cat2"],
    )
    m = TOTALVI(data)
    m.train(1)
def unsupervised_training_one_epoch(
    adata: AnnData,
    run_setup_anndata: bool = True,
    batch_key: Optional[str] = None,
    labels_key: Optional[str] = None,
):
    """Run one epoch of unsupervised SCVI training on ``adata``.

    Optionally registers the AnnData first via ``SCVI.setup_anndata``.
    """
    if run_setup_anndata:
        SCVI.setup_anndata(adata, batch_key=batch_key, labels_key=labels_key)
    # NOTE(review): `use_gpu` is not defined in this function — presumably a
    # module-level setting; confirm it exists at module scope.
    model = SCVI(adata)
    model.train(1, train_size=0.4, use_gpu=use_gpu)
def test_solo_multiple_batch(save_path):
    """SOLO restricted to a single batch of a multi-batch SCVI model trains and predicts."""
    data = synthetic_iid()
    data.layers["my_layer"] = data.X.copy()
    SCVI.setup_anndata(data, layer="my_layer", batch_key="batch")
    vae = SCVI(data, n_latent=5)
    vae.train(1, check_val_every_n_epoch=1, train_size=0.5)
    solo = SOLO.from_scvi_model(vae, restrict_to_batch="batch_0")
    solo.train(1, check_val_every_n_epoch=1, train_size=0.9)
    assert "validation_loss" in solo.history
    solo.predict()
def test_early_stopping():
    """With lr=0 the ELBO cannot improve, so early stopping must halt training
    well before the epoch budget is exhausted."""
    max_epochs = 100
    data = synthetic_iid(run_setup_anndata=False)
    SCVI.setup_anndata(data, batch_key="batch", labels_key="labels")
    model = SCVI(data)
    model.train(max_epochs, early_stopping=True, plan_kwargs={"lr": 0})
    assert len(model.history["elbo_train"]) < max_epochs
def test_backed_anndata_scvi(save_path):
    """SCVI trains and runs inference on a disk-backed (HDF5) AnnData."""
    data = scvi.data.synthetic_iid()
    path = os.path.join(save_path, "test_data.h5ad")
    data.write_h5ad(path)
    # reopen in backed read/write mode so X stays on disk
    backed = anndata.read_h5ad(path, backed="r+")
    SCVI.setup_anndata(backed, batch_key="batch")
    model = SCVI(backed, n_latent=5)
    model.train(1, train_size=0.5)
    assert model.is_trained is True
    latent = model.get_latent_representation()
    assert latent.shape == (backed.shape[0], 5)
    model.get_elbo()
def test_set_seed(save_path):
    """Re-seeding before construction+training must reproduce identical encoder weights."""
    scvi.settings.seed = 1
    n_latent = 5
    data = synthetic_iid()
    SCVI.setup_anndata(data, batch_key="batch", labels_key="labels")
    first = SCVI(data, n_latent=n_latent)
    first.train(1)
    # reset to the same seed and repeat the exact sequence
    scvi.settings.seed = 1
    second = SCVI(data, n_latent=n_latent)
    second.train(1)
    first_weight = first.module.z_encoder.encoder.fc_layers[0][0].weight
    second_weight = second.module.z_encoder.encoder.fc_layers[0][0].weight
    assert torch.equal(first_weight, second_weight)
def test_scvi_sparse(save_path):
    """SCVI runs end-to-end when the count matrix is CSR sparse."""
    n_latent = 5
    data = synthetic_iid(run_setup_anndata=False)
    data.X = csr_matrix(data.X)
    SCVI.setup_anndata(data)
    model = SCVI(data, n_latent=n_latent)
    model.train(1, train_size=0.5)
    assert model.is_trained is True
    latent = model.get_latent_representation()
    assert latent.shape == (data.shape[0], n_latent)
    # exercise the main inference entry points on sparse input
    model.get_elbo()
    model.get_marginal_ll(n_mc_samples=3)
    model.get_reconstruction_error()
    model.get_normalized_expression()
    model.differential_expression(groupby="labels", group1="label_1")
def test_solo(save_path):
    """SOLO doublet detection works from an SCVI model, both on the model's own
    data and on a second, compatible AnnData."""
    data = synthetic_iid(run_setup_anndata=False)
    SCVI.setup_anndata(data)
    vae = SCVI(data, n_latent=5)
    vae.train(1, check_val_every_n_epoch=1, train_size=0.5)

    # SOLO derived from the model's own training data
    solo = SOLO.from_scvi_model(vae)
    solo.train(1, check_val_every_n_epoch=1, train_size=0.9)
    assert "validation_loss" in solo.history
    solo.predict()

    # SOLO on a fresh AnnData compatible with the trained model
    other = synthetic_iid(run_setup_anndata=False)
    solo = SOLO.from_scvi_model(vae, other)
    solo.train(1, check_val_every_n_epoch=1, train_size=0.9)
    assert "validation_loss" in solo.history
    solo.predict()
def test_new_setup_compat():
    """Verify the new AnnData-manager registry matches both the legacy setup-dict
    path and the transfer_setup path."""
    adata = synthetic_iid()
    # extra covariates so the registry contains categorical + continuous fields
    adata.obs["cat1"] = np.random.randint(0, 5, size=(adata.shape[0], ))
    adata.obs["cat2"] = np.random.randint(0, 5, size=(adata.shape[0], ))
    adata.obs["cont1"] = np.random.normal(size=(adata.shape[0], ))
    adata.obs["cont2"] = np.random.normal(size=(adata.shape[0], ))
    # copies taken BEFORE setup so each path starts from unregistered data
    adata2 = adata.copy()
    adata3 = adata.copy()
    SCVI.setup_anndata(
        adata,
        batch_key="batch",
        labels_key="labels",
        categorical_covariate_keys=["cat1", "cat2"],
        continuous_covariate_keys=["cont1", "cont2"],
    )
    model = SCVI(adata)
    adata_manager = model.adata_manager
    model.view_anndata_setup(hide_state_registries=True)
    field_registries = adata_manager.registry[_constants._FIELD_REGISTRIES_KEY]
    # only the fields that existed in the legacy registry format are comparable
    field_registries_legacy_subset = {
        k: v
        for k, v in field_registries.items() if k in LEGACY_REGISTRY_KEYS
    }
    # Backwards compatibility test.
    adata2_manager = manager_from_setup_dict(SCVI, adata2, LEGACY_SETUP_DICT)
    np.testing.assert_equal(
        field_registries_legacy_subset,
        adata2_manager.registry[_constants._FIELD_REGISTRIES_KEY],
    )
    # Test transfer.
    adata3_manager = adata_manager.transfer_setup(adata3)
    np.testing.assert_equal(
        field_registries,
        adata3_manager.registry[_constants._FIELD_REGISTRIES_KEY],
    )
def test_device_backed_data_splitter():
    """DeviceBackedDataSplitter with train_size=1.0 loads the full dataset and
    can drive a one-epoch training run."""
    adata = synthetic_iid()
    SCVI.setup_anndata(adata, batch_key="batch", labels_key="labels")
    model = SCVI(adata, n_latent=5)
    manager = model.adata_manager
    # leaving validation_size unset must work
    splitter = DeviceBackedDataSplitter(manager, train_size=1.0, use_gpu=None)
    splitter.setup()
    train_loader = splitter.train_dataloader()
    splitter.val_dataloader()
    batch_x = next(iter(train_loader))["X"]
    # with train_size=1.0 the single training batch holds every cell
    assert len(batch_x) == adata.shape[0]
    np.testing.assert_array_equal(batch_x.cpu().numpy(), adata.X)
    plan = TrainingPlan(model.module, len(splitter.train_idx))
    runner = TrainRunner(
        model,
        training_plan=plan,
        data_splitter=splitter,
        max_epochs=1,
        use_gpu=None,
    )
    runner()
def from_scvi_model(
    cls,
    scvi_model: SCVI,
    adata: Optional[AnnData] = None,
    restrict_to_batch: Optional[str] = None,
    doublet_ratio: int = 2,
    **classifier_kwargs,
):
    """
    Instantiate a SOLO model from an scvi model.

    Parameters
    ----------
    scvi_model
        Pre-trained model of :class:`~scvi.model.SCVI`. The adata object used to
        initialize this model should have only been setup with count data, and
        optionally a `batch_key`; i.e., no extra covariates or labels, etc.
    adata
        Optional anndata to use that is compatible with scvi_model.
    restrict_to_batch
        Batch category in `batch_key` used to setup adata for scvi_model to
        restrict Solo model to. This allows to train a Solo model on one batch
        of a scvi_model that was trained on multiple batches.
    doublet_ratio
        Ratio of generated doublets to produce relative to number of cells in
        adata or length of indices, if not `None`.
    **classifier_kwargs
        Keyword args for :class:`~scvi.module.Classifier`

    Returns
    -------
    SOLO model
    """
    # NOTE(review): expected to be decorated with @classmethod (decorator not
    # visible in this chunk) — it takes `cls` first.
    _validate_scvi_model(scvi_model, restrict_to_batch=restrict_to_batch)
    orig_adata_manager = scvi_model.adata_manager
    orig_batch_key = orig_adata_manager.get_state_registry(
        REGISTRY_KEYS.BATCH_KEY).original_key
    if adata is not None:
        # register the user-supplied AnnData against the pre-trained model's setup
        adata_manager = orig_adata_manager.transfer_setup(adata)
        cls.register_manager(adata_manager)
    else:
        # fall back to the AnnData the scvi model was trained on
        adata_manager = orig_adata_manager
        adata = adata_manager.adata
    if restrict_to_batch is not None:
        batch_mask = adata.obs[orig_batch_key] == restrict_to_batch
        if np.sum(batch_mask) == 0:
            raise ValueError(
                "Batch category given to restrict_to_batch not found.\n" +
                "Available categories: {}".format(
                    adata.obs[orig_batch_key].astype(
                        "category").cat.categories))
        # indices in adata with restrict_to_batch category
        batch_indices = np.where(batch_mask)[0]
    else:
        # use all indices
        batch_indices = None
    # anndata with only generated doublets
    doublet_adata = cls.create_doublets(adata_manager,
                                        indices=batch_indices,
                                        doublet_ratio=doublet_ratio)
    # if scvi wasn't trained with batch correction, having the zeros here
    # does nothing
    doublet_adata.obs[orig_batch_key] = (
        restrict_to_batch if restrict_to_batch is not None else 0)
    # if model is using observed lib size, needs to get lib sample
    # which is just observed lib size on log scale
    give_mean_lib = not scvi_model.module.use_observed_lib_size
    # get latent representations and make input anndata
    latent_rep = scvi_model.get_latent_representation(
        adata, indices=batch_indices)
    lib_size = scvi_model.get_latent_library_size(adata,
                                                  indices=batch_indices,
                                                  give_mean=give_mean_lib)
    # SOLO classifies on [latent, log library size] features
    latent_adata = AnnData(
        np.concatenate([latent_rep, np.log(lib_size)], axis=1))
    latent_adata.obs[LABELS_KEY] = "singlet"
    orig_obs_names = adata.obs_names
    latent_adata.obs_names = (orig_obs_names[batch_indices]
                              if batch_indices is not None else orig_obs_names)
    logger.info("Creating doublets, preparing SOLO model.")
    f = io.StringIO()
    # silence setup_anndata's printed summary
    with redirect_stdout(f):
        scvi_model.setup_anndata(doublet_adata, batch_key=orig_batch_key)
        doublet_latent_rep = scvi_model.get_latent_representation(
            doublet_adata)
        doublet_lib_size = scvi_model.get_latent_library_size(
            doublet_adata, give_mean=give_mean_lib)
        doublet_adata = AnnData(
            np.concatenate([doublet_latent_rep, np.log(doublet_lib_size)],
                           axis=1))
        doublet_adata.obs[LABELS_KEY] = "doublet"
    # combined singlet + doublet dataset is what SOLO trains on
    full_adata = latent_adata.concatenate(doublet_adata)
    cls.setup_anndata(full_adata, labels_key=LABELS_KEY)
    return cls(full_adata, **classifier_kwargs)
def test_scvi(save_path):
    """End-to-end SCVI test: training, inference entry points, output shapes,
    setup transfer, differential expression, and training callbacks."""
    n_latent = 5
    adata = synthetic_iid(run_setup_anndata=False)
    SCVI.setup_anndata(
        adata,
        batch_key="batch",
        labels_key="labels",
    )
    # Test with observed lib size.
    adata = synthetic_iid(run_setup_anndata=False)
    SCVI.setup_anndata(
        adata,
        batch_key="batch",
        labels_key="labels",
    )
    model = SCVI(adata, n_latent=n_latent)
    model.train(1, check_val_every_n_epoch=1, train_size=0.5)
    model = SCVI(adata,
                 n_latent=n_latent,
                 var_activation=Softplus(),
                 use_observed_lib_size=False)
    # train twice: history should accumulate to 2 epochs (asserted below)
    model.train(1, check_val_every_n_epoch=1, train_size=0.5)
    model.train(1, check_val_every_n_epoch=1, train_size=0.5)
    # tests __repr__
    print(model)
    assert model.is_trained is True
    z = model.get_latent_representation()
    assert z.shape == (adata.shape[0], n_latent)
    assert len(model.history["elbo_train"]) == 2
    model.get_elbo()
    model.get_marginal_ll(n_mc_samples=3)
    model.get_reconstruction_error()
    model.get_normalized_expression(transform_batch="batch_1")
    # same methods on a second AnnData (triggers automatic setup transfer)
    adata2 = synthetic_iid()
    model.get_elbo(adata2)
    model.get_marginal_ll(adata2, n_mc_samples=3)
    model.get_reconstruction_error(adata2)
    latent = model.get_latent_representation(adata2, indices=[1, 2, 3])
    assert latent.shape == (3, n_latent)
    denoised = model.get_normalized_expression(adata2)
    assert denoised.shape == adata.shape
    denoised = model.get_normalized_expression(adata2,
                                               indices=[1, 2, 3],
                                               transform_batch="batch_1")
    denoised = model.get_normalized_expression(
        adata2, indices=[1, 2, 3], transform_batch=["batch_0", "batch_1"])
    assert denoised.shape == (3, adata2.n_vars)
    sample = model.posterior_predictive_sample(adata2)
    assert sample.shape == adata2.shape
    sample = model.posterior_predictive_sample(adata2,
                                               indices=[1, 2, 3],
                                               gene_list=["1", "2"])
    assert sample.shape == (3, 2)
    sample = model.posterior_predictive_sample(adata2,
                                               indices=[1, 2, 3],
                                               gene_list=["1", "2"],
                                               n_samples=3)
    assert sample.shape == (3, 2, 3)
    model.get_feature_correlation_matrix(correlation_type="pearson")
    model.get_feature_correlation_matrix(
        adata2,
        indices=[1, 2, 3],
        correlation_type="spearman",
        rna_size_factor=500,
        n_samples=5,
    )
    model.get_feature_correlation_matrix(
        adata2,
        indices=[1, 2, 3],
        correlation_type="spearman",
        rna_size_factor=500,
        n_samples=5,
        transform_batch=["batch_0", "batch_1"],
    )
    params = model.get_likelihood_parameters()
    assert params["mean"].shape == adata.shape
    assert (params["mean"].shape == params["dispersions"].shape ==
            params["dropout"].shape)
    params = model.get_likelihood_parameters(adata2, indices=[1, 2, 3])
    assert params["mean"].shape == (3, adata.n_vars)
    params = model.get_likelihood_parameters(adata2,
                                             indices=[1, 2, 3],
                                             n_samples=3,
                                             give_mean=True)
    assert params["mean"].shape == (3, adata.n_vars)
    model.get_latent_library_size()
    model.get_latent_library_size(adata2, indices=[1, 2, 3])
    # test transfer_anndata_setup
    adata2 = synthetic_iid(run_setup_anndata=False)
    transfer_anndata_setup(adata, adata2)
    model.get_elbo(adata2)
    # test automatic transfer_anndata_setup + on a view
    adata = synthetic_iid()
    model = SCVI(adata)
    adata2 = synthetic_iid(run_setup_anndata=False)
    model.get_elbo(adata2[:10])
    # test that we catch incorrect mappings
    adata = synthetic_iid()
    adata2 = synthetic_iid(run_setup_anndata=False)
    transfer_anndata_setup(adata, adata2)
    # "label_4" is not in the model's label mapping -> must raise
    adata2.uns["_scvi"]["categorical_mappings"]["_scvi_labels"][
        "mapping"] = np.array(["label_4", "label_0", "label_2"])
    with pytest.raises(ValueError):
        model.get_elbo(adata2)
    # test that same mapping different order doesn't raise error
    adata = synthetic_iid()
    adata2 = synthetic_iid(run_setup_anndata=False)
    transfer_anndata_setup(adata, adata2)
    adata2.uns["_scvi"]["categorical_mappings"]["_scvi_labels"][
        "mapping"] = np.array(["label_1", "label_0", "label_2"])
    model.get_elbo(adata2)  # should automatically transfer setup
    # test mismatched categories raises ValueError
    adata2 = synthetic_iid(run_setup_anndata=False)
    adata2.obs.labels.cat.rename_categories(["a", "b", "c"], inplace=True)
    with pytest.raises(ValueError):
        model.get_elbo(adata2)
    # test differential expression
    model.differential_expression(groupby="labels", group1="label_1")
    model.differential_expression(groupby="labels",
                                  group1="label_1",
                                  group2="label_2",
                                  mode="change")
    model.differential_expression(groupby="labels")
    model.differential_expression(idx1=[0, 1, 2], idx2=[3, 4, 5])
    model.differential_expression(idx1=[0, 1, 2])
    # transform batch works with all different types
    a = synthetic_iid(run_setup_anndata=False)
    batch = np.zeros(a.n_obs)
    batch[:64] += 1
    a.obs["batch"] = batch
    _setup_anndata(a, batch_key="batch")
    m = SCVI(a)
    m.train(1, train_size=0.5)
    m.get_normalized_expression(transform_batch=1)
    m.get_normalized_expression(transform_batch=[0, 1])
    # test get_likelihood_parameters() when dispersion=='gene-cell'
    model = SCVI(adata, dispersion="gene-cell")
    model.get_likelihood_parameters()
    # test train callbacks work
    a = synthetic_iid()
    m = scvi.model.SCVI(a)
    lr_monitor = LearningRateMonitor()
    m.train(
        callbacks=[lr_monitor],
        max_epochs=10,
        check_val_every_n_epoch=1,
        log_every_n_steps=1,
        plan_kwargs={"reduce_lr_on_plateau": True},
    )
    # LearningRateMonitor logs the optimizer LR under this key
    assert "lr-Adam" in m.history.keys()
def test_differential_computation(save_path):
    """Exercise DifferentialComputation / differential_expression across modes,
    custom change functions, query-string groups, and non-categorical groupby."""
    n_latent = 5
    adata = synthetic_iid()
    SCVI.setup_anndata(
        adata,
        batch_key="batch",
        labels_key="labels",
    )
    model = SCVI(adata, n_latent=n_latent)
    model.train(1)
    model_fn = partial(model.get_normalized_expression, return_numpy=True)
    dc = DifferentialComputation(model_fn, model.adata_manager)
    # boolean masks splitting cells into label_1 vs the rest
    cell_idx1 = np.asarray(adata.obs.labels == "label_1")
    cell_idx2 = ~cell_idx1
    dc.get_bayes_factors(cell_idx1,
                         cell_idx2,
                         mode="vanilla",
                         use_permutation=True)
    res = dc.get_bayes_factors(cell_idx1,
                               cell_idx2,
                               mode="change",
                               use_permutation=False)
    # defaults for "change" mode: delta=0.5, pseudocounts=0.0
    assert (res["delta"] == 0.5) and (res["pseudocounts"] == 0.0)
    res = dc.get_bayes_factors(cell_idx1,
                               cell_idx2,
                               mode="change",
                               use_permutation=False,
                               delta=None)
    dc.get_bayes_factors(
        cell_idx1,
        cell_idx2,
        mode="change",
        use_permutation=False,
        delta=None,
        pseudocounts=None,
    )
    dc.get_bayes_factors(cell_idx1,
                         cell_idx2,
                         mode="change",
                         cred_interval_lvls=[0.75])
    delta = 0.5

    # user-supplied change function and M1-region indicator for "change" mode
    def change_fn_test(x, y):
        return x - y

    def m1_domain_fn_test(samples):
        return np.abs(samples) >= delta

    dc.get_bayes_factors(
        cell_idx1,
        cell_idx2,
        mode="change",
        m1_domain_fn=m1_domain_fn_test,
        change_fn=change_fn_test,
    )
    # should fail if just one batch
    with pytest.raises(ValueError):
        model.differential_expression(adata[:20], groupby="batch")
    # test view
    model.differential_expression(adata[adata.obs["labels"] == "label_1"],
                                  groupby="batch")
    # Test query features
    obs_col, group1, _, = _prepare_obs(
        idx1="(labels == 'label_1') & (batch == 'batch_1')",
        idx2=None,
        adata=adata)
    # mask from the query string must match a direct pandas selection
    assert (obs_col == group1
            ).sum() == adata.obs.loc[lambda x: (x.labels == "label_1") &
                                     (x.batch == "batch_1")].shape[0]
    model.differential_expression(idx1="labels == 'label_1'", )
    model.differential_expression(
        idx1="labels == 'label_1'",
        idx2="(labels == 'label_2') & (batch == 'batch_1')")
    # test that ints as group work
    a = synthetic_iid()
    SCVI.setup_anndata(
        a,
        batch_key="batch",
        labels_key="labels",
    )
    a.obs["test"] = [0] * 200 + [1] * 200
    model = SCVI(a)
    model.differential_expression(groupby="test", group1=0)
    # test that string but not as categorical work
    a = synthetic_iid()
    SCVI.setup_anndata(
        a,
        batch_key="batch",
        labels_key="labels",
    )
    a.obs["test"] = ["0"] * 200 + ["1"] * 200
    model = SCVI(a)
    model.differential_expression(groupby="test", group1="0")
def test_scanvi(save_path):
    """SCANVI end-to-end: training metrics, prediction, DE, all-labeled and
    mixed-label data, and from_scvi_model conversion (incl. size factors)."""
    adata = synthetic_iid()
    SCANVI.setup_anndata(
        adata,
        "label_0",
        batch_key="batch",
        labels_key="labels",
    )
    model = SCANVI(adata, n_latent=10)
    model.train(1, train_size=0.5, check_val_every_n_epoch=1)
    # both train and validation metrics must be logged, plus SCANVI's
    # classifier loss on the validation split
    logged_keys = model.history.keys()
    assert "elbo_validation" in logged_keys
    assert "reconstruction_loss_validation" in logged_keys
    assert "kl_local_validation" in logged_keys
    assert "elbo_train" in logged_keys
    assert "reconstruction_loss_train" in logged_keys
    assert "kl_local_train" in logged_keys
    assert "classification_loss_validation" in logged_keys
    adata2 = synthetic_iid()
    predictions = model.predict(adata2, indices=[1, 2, 3])
    assert len(predictions) == 3
    model.predict()
    # soft=True returns per-class probabilities as a DataFrame
    df = model.predict(adata2, soft=True)
    assert isinstance(df, pd.DataFrame)
    model.predict(adata2, soft=True, indices=[1, 2, 3])
    model.get_normalized_expression(adata2)
    model.differential_expression(groupby="labels", group1="label_1")
    model.differential_expression(groupby="labels",
                                  group1="label_1",
                                  group2="label_2")
    # test that all data labeled runs
    unknown_label = "asdf"
    a = scvi.data.synthetic_iid()
    scvi.model.SCANVI.setup_anndata(a,
                                    unknown_label,
                                    batch_key="batch",
                                    labels_key="labels")
    m = scvi.model.SCANVI(a)
    m.train(1)
    # test mix of labeled and unlabeled data
    unknown_label = "label_0"
    a = scvi.data.synthetic_iid()
    scvi.model.SCANVI.setup_anndata(a,
                                    unknown_label,
                                    batch_key="batch",
                                    labels_key="labels")
    m = scvi.model.SCANVI(a)
    m.train(1, train_size=0.9)
    # test from_scvi_model
    a = scvi.data.synthetic_iid()
    SCVI.setup_anndata(
        a,
        batch_key="batch",
        labels_key="labels",
    )
    m = SCVI(a, use_observed_lib_size=False)
    a2 = scvi.data.synthetic_iid()
    scanvi_model = scvi.model.SCANVI.from_scvi_model(m, "label_0", adata=a2)
    # make sure the state_dicts are different objects for the two models
    assert scanvi_model.module.state_dict() is not m.module.state_dict()
    scanvi_pxr = scanvi_model.module.state_dict().get("px_r", None)
    scvi_pxr = m.module.state_dict().get("px_r", None)
    assert scanvi_pxr is not None and scvi_pxr is not None
    # the dispersion tensor must have been copied, not shared
    assert scanvi_pxr is not scvi_pxr
    scanvi_model.train(1)
    # Test without label groups
    scanvi_model = scvi.model.SCANVI.from_scvi_model(m,
                                                     "label_0",
                                                     use_labels_groups=False)
    scanvi_model.train(1)
    # test from_scvi_model with size_factor
    a = scvi.data.synthetic_iid()
    a.obs["size_factor"] = np.random.randint(1, 5, size=(a.shape[0], ))
    SCVI.setup_anndata(a,
                       batch_key="batch",
                       labels_key="labels",
                       size_factor_key="size_factor")
    m = SCVI(a, use_observed_lib_size=False)
    a2 = scvi.data.synthetic_iid()
    a2.obs["size_factor"] = np.random.randint(1, 5, size=(a2.shape[0], ))
    scanvi_model = scvi.model.SCANVI.from_scvi_model(m, "label_0", adata=a2)
    scanvi_model.train(1)
def test_scvi_online_update(save_path):
    """Test scArches-style online updates via load_query_data: weight reuse,
    parameter freezing, and gradient flow on reference vs. new parameters."""
    n_latent = 5
    adata1 = synthetic_iid()
    SCVI.setup_anndata(adata1, batch_key="batch", labels_key="labels")
    model = SCVI(adata1, n_latent=n_latent)
    model.train(1, check_val_every_n_epoch=1)
    dir_path = os.path.join(save_path, "saved_model/")
    model.save(dir_path, overwrite=True)
    # also test subset var option
    adata2 = synthetic_iid(n_genes=110)
    # new batch categories so query data extends the batch mapping
    adata2.obs["batch"] = adata2.obs.batch.cat.rename_categories(
        ["batch_2", "batch_3"])
    model2 = SCVI.load_query_data(adata2,
                                  dir_path,
                                  inplace_subset_query_vars=True)
    model2.train(max_epochs=1, plan_kwargs=dict(weight_decay=0.0))
    model2.get_latent_representation()
    # encoder linear layer equal: reference-gene columns must be reused verbatim
    one = (model.module.z_encoder.encoder.fc_layers[0]
           [0].weight.detach().cpu().numpy()[:, :adata1.shape[1]])
    two = (model2.module.z_encoder.encoder.fc_layers[0]
           [0].weight.detach().cpu().numpy()[:, :adata1.shape[1]])
    np.testing.assert_equal(one, two)
    # frozen reference weights must receive zero gradient after a backward pass
    single_pass_for_online_update(model2)
    assert (np.sum(model2.module.z_encoder.encoder.fc_layers[0]
                   [0].weight.grad.cpu().numpy()[:, :adata1.shape[1]]) == 0)
    # dispersion
    assert model2.module.px_r.requires_grad is False
    # library encoder linear layer
    assert model2.module.l_encoder.encoder.fc_layers[0][
        0].weight.requires_grad is True
    # 5 for n_latent, 4 for batches
    assert model2.module.decoder.px_decoder.fc_layers[0][0].weight.shape[
        1] == 9
    # test options
    adata1 = synthetic_iid()
    SCVI.setup_anndata(adata1, batch_key="batch", labels_key="labels")
    model = SCVI(
        adata1,
        n_latent=n_latent,
        n_layers=2,
        encode_covariates=True,
        use_batch_norm="encoder",
        use_layer_norm="none",
    )
    model.train(1, check_val_every_n_epoch=1)
    dir_path = os.path.join(save_path, "saved_model/")
    model.save(dir_path, overwrite=True)
    adata2 = synthetic_iid()
    adata2.obs["batch"] = adata2.obs.batch.cat.rename_categories(
        ["batch_2", "batch_3"])
    model2 = SCVI.load_query_data(adata2, dir_path, freeze_expression=True)
    model2.train(max_epochs=1, plan_kwargs=dict(weight_decay=0.0))
    # deactivate no grad decorator
    model2.get_latent_representation()
    # pytorch lightning zeros the grad, so this will get a grad to inspect
    single_pass_for_online_update(model2)
    grad = model2.module.z_encoder.encoder.fc_layers[0][0].weight.grad.cpu(
    ).numpy()
    # expression part has zero grad
    assert np.sum(grad[:, :-4]) == 0
    # categorical part has non-zero grad
    assert np.sum(grad[:, -4:]) != 0
    grad = model2.module.decoder.px_decoder.fc_layers[0][0].weight.grad.cpu(
    ).numpy()
    # expression part of decoder first-layer weight has zero grad (frozen)
    assert np.sum(grad[:, :-4]) == 0
    # do not freeze expression
    model3 = SCVI.load_query_data(
        adata2,
        dir_path,
        freeze_expression=False,
        freeze_batchnorm_encoder=True,
        freeze_decoder_first_layer=False,
    )
    model3.train(max_epochs=1)
    model3.get_latent_representation()
    # frozen batchnorm: momentum zeroed so running stats stay fixed
    assert model3.module.z_encoder.encoder.fc_layers[0][1].momentum == 0
    # batch norm weight in encoder layer
    assert model3.module.z_encoder.encoder.fc_layers[0][
        1].weight.requires_grad is False
    single_pass_for_online_update(model3)
    grad = model3.module.z_encoder.encoder.fc_layers[0][0].weight.grad.cpu(
    ).numpy()
    # linear layer weight in encoder layer has non-zero grad
    assert np.sum(grad[:, :-4]) != 0
    grad = model3.module.decoder.px_decoder.fc_layers[0][0].weight.grad.cpu(
    ).numpy()
    # linear layer weight in decoder layer has non-zero grad
    assert np.sum(grad[:, :-4]) != 0
    # do not freeze batchnorm
    model3 = SCVI.load_query_data(adata2,
                                  dir_path,
                                  freeze_batchnorm_encoder=False)
    model3.train(max_epochs=1)
    model3.get_latent_representation()