def test_scvi_library_size_update(save_path):
    """Check that per-batch library-size priors are extended on query load.

    Trains an SCVI model with a modeled (not observed) library size on a
    2-batch reference, saves it, then loads query data with 2 new batches
    (and a different gene set, exercising ``inplace_subset_query_vars``).
    The query model's ``library_log_means``/``library_log_vars`` buffers
    must grow from 2 to 4 batch columns while keeping the reference
    columns byte-identical.
    """
    n_latent = 5
    adata1 = synthetic_iid()
    SCVI.setup_anndata(adata1, batch_key="batch", labels_key="labels")
    # use_observed_lib_size=False forces the module to register the
    # library_log_means / library_log_vars prior buffers under test.
    model = SCVI(adata1, n_latent=n_latent, use_observed_lib_size=False)

    # reference model: one prior column per reference batch (2 batches)
    assert (
        getattr(model.module, "library_log_means", None) is not None
        and model.module.library_log_means.shape == (1, 2)
        and model.module.library_log_means.count_nonzero().item() == 2
    )
    assert (
        getattr(model.module, "library_log_vars", None) is not None
        and model.module.library_log_vars.shape == (1, 2)
    )

    model.train(1, check_val_every_n_epoch=1)
    dir_path = os.path.join(save_path, "saved_model/")
    model.save(dir_path, overwrite=True)

    # also test subset var option: query has extra genes that must be
    # subset in place to the reference var space
    adata2 = synthetic_iid(n_genes=110)
    adata2.obs["batch"] = adata2.obs.batch.cat.rename_categories(
        ["batch_2", "batch_3"]
    )

    model2 = SCVI.load_query_data(adata2, dir_path, inplace_subset_query_vars=True)

    # query model: 2 reference + 2 query batches = 4 prior columns; the
    # first two columns must equal the reference model's values exactly
    assert (
        getattr(model2.module, "library_log_means", None) is not None
        and model2.module.library_log_means.shape == (1, 4)
        and model2.module.library_log_means[:, :2].equal(
            model.module.library_log_means
        )
        and model2.module.library_log_means.count_nonzero().item() == 4
    )
    assert (
        getattr(model2.module, "library_log_vars", None) is not None
        and model2.module.library_log_vars.shape == (1, 4)
        and model2.module.library_log_vars[:, :2].equal(
            model.module.library_log_vars
        )
    )
def test_scvi_online_update(save_path):
    """Exercise scArches-style online updating of a saved SCVI model.

    Covers three query-load configurations:

    1. Default freezing with ``inplace_subset_query_vars`` — reference
       expression weights/grads frozen, dispersion frozen, library encoder
       trainable, decoder input widened for the new batches.
    2. ``freeze_expression=True`` with covariate-encoding options — only
       the categorical-covariate columns of the first layers receive grad.
    3. ``freeze_expression=False`` with ``freeze_batchnorm_encoder=True``
       and an unfrozen decoder first layer — expression weights get grad,
       encoder batch-norm stays frozen; finally batch norm is unfrozen.
    """
    n_latent = 5
    adata1 = synthetic_iid()
    model = SCVI(adata1, n_latent=n_latent)
    model.train(1, check_val_every_n_epoch=1)
    dir_path = os.path.join(save_path, "saved_model/")
    model.save(dir_path, overwrite=True)

    # also test subset var option: query has extra genes
    adata2 = synthetic_iid(run_setup_anndata=False, n_genes=110)
    adata2.obs["batch"] = adata2.obs.batch.cat.rename_categories(
        ["batch_2", "batch_3"]
    )

    model2 = SCVI.load_query_data(adata2, dir_path, inplace_subset_query_vars=True)
    model2.train(max_epochs=1, plan_kwargs=dict(weight_decay=0.0))
    model2.get_latent_representation()

    # encoder linear layer equal: reference-gene columns of the first
    # encoder layer must be unchanged by query training
    one = (
        model.module.z_encoder.encoder.fc_layers[0][0]
        .weight.detach()
        .cpu()
        .numpy()[:, : adata1.shape[1]]
    )
    two = (
        model2.module.z_encoder.encoder.fc_layers[0][0]
        .weight.detach()
        .cpu()
        .numpy()[:, : adata1.shape[1]]
    )
    np.testing.assert_equal(one, two)
    # ...and those frozen columns accumulated no gradient
    assert (
        np.sum(
            model2.module.z_encoder.encoder.fc_layers[0][0]
            .weight.grad.cpu()
            .numpy()[:, : adata1.shape[1]]
        )
        == 0
    )
    # dispersion parameters stay frozen
    assert model2.module.px_r.requires_grad is False
    # library encoder linear layer remains trainable
    assert model2.module.l_encoder.encoder.fc_layers[0][0].weight.requires_grad is True
    # decoder first-layer input width: 5 for n_latent, 4 for batches
    assert model2.module.decoder.px_decoder.fc_layers[0][0].weight.shape[1] == 9

    # test options: deeper encoder with covariates and batch norm
    adata1 = synthetic_iid()
    model = SCVI(
        adata1,
        n_latent=n_latent,
        n_layers=2,
        encode_covariates=True,
        use_batch_norm="encoder",
        use_layer_norm="none",
    )
    model.train(1, check_val_every_n_epoch=1)
    dir_path = os.path.join(save_path, "saved_model/")
    model.save(dir_path, overwrite=True)

    adata2 = synthetic_iid(run_setup_anndata=False)
    adata2.obs["batch"] = adata2.obs.batch.cat.rename_categories(
        ["batch_2", "batch_3"]
    )

    model2 = SCVI.load_query_data(adata2, dir_path, freeze_expression=True)
    model2.train(max_epochs=1, plan_kwargs=dict(weight_decay=0.0))
    # deactivate no grad decorator
    model2.get_latent_representation()
    # pytorch lightning zeros the grad, so this will get a grad to inspect
    single_pass_for_online_update(model2)
    grad = model2.module.z_encoder.encoder.fc_layers[0][0].weight.grad.cpu().numpy()
    # expression part has zero grad (last 4 columns are the batch one-hots)
    assert np.sum(grad[:, :-4]) == 0
    # categorical part has non-zero grad
    assert np.sum(grad[:, -4:]) != 0
    grad = model2.module.decoder.px_decoder.fc_layers[0][0].weight.grad.cpu().numpy()
    # linear layer weight in decoder layer has zero grad (expression frozen)
    assert np.sum(grad[:, :-4]) == 0

    # do not freeze expression
    model3 = SCVI.load_query_data(
        adata2,
        dir_path,
        freeze_expression=False,
        freeze_batchnorm_encoder=True,
        freeze_decoder_first_layer=False,
    )
    model3.train(max_epochs=1)
    model3.get_latent_representation()
    # frozen batch norm: momentum zeroed so running stats stop updating
    assert model3.module.z_encoder.encoder.fc_layers[0][1].momentum == 0
    # batch norm weight in encoder layer is frozen
    assert model3.module.z_encoder.encoder.fc_layers[0][1].weight.requires_grad is False
    single_pass_for_online_update(model3)
    grad = model3.module.z_encoder.encoder.fc_layers[0][0].weight.grad.cpu().numpy()
    # linear layer weight in encoder layer has non-zero grad
    assert np.sum(grad[:, :-4]) != 0
    grad = model3.module.decoder.px_decoder.fc_layers[0][0].weight.grad.cpu().numpy()
    # linear layer weight in decoder layer has non-zero grad
    assert np.sum(grad[:, :-4]) != 0

    # do not freeze batchnorm
    model3 = SCVI.load_query_data(adata2, dir_path, freeze_batchnorm_encoder=False)
    model3.train(max_epochs=1)
    model3.get_latent_representation()