def test_peakvi():
    """Smoke-test PEAKVI training under several configurations and its query API."""
    adata = synthetic_iid()

    # Configuration 1: disable the model-depth (cell-wise) term.
    model = PEAKVI(
        adata,
        model_depth=False,
    )
    model.train(1, save_best=False)

    # Configuration 2: disable per-region scaling factors.
    model = PEAKVI(
        adata,
        region_factors=False,
    )
    model.train(1, save_best=False)

    # Configuration 3: defaults — exercise every post-training query method.
    model = PEAKVI(
        adata,
    )
    model.train(3)
    model.get_elbo(indices=model.validation_indices)
    model.get_accessibility_estimates()
    model.get_accessibility_estimates(normalize_cells=True)
    model.get_accessibility_estimates(normalize_regions=True)
    model.get_library_size_factors()
    model.get_region_factors()
    model.get_reconstruction_error(indices=model.validation_indices)
    model.get_latent_representation()
    model.differential_accessibility(groupby="labels", group1="label_1")
def test_peakvi_online_update(save_path):
    """Exercise PEAKVI's scArches-style online update (query-data surgery).

    Trains a reference model, saves it, then loads query data into it and
    checks which encoder/decoder weights are frozen (zero grad) versus
    trainable (non-zero grad) under the different freeze options.
    """
    n_latent = 5

    def encoder_input_weights(m, n_cols):
        # Weights of the encoder's first linear layer, restricted to the
        # leading `n_cols` input columns (the original peak features).
        return (
            m.module.z_encoder.encoder.fc_layers[0][0]
            .weight.detach()
            .cpu()
            .numpy()[:, :n_cols]
        )

    # Train a reference model and persist it to disk.
    adata1 = synthetic_iid()
    model = PEAKVI(adata1, n_latent=n_latent)
    model.train(1, save_best=False)
    dir_path = os.path.join(save_path, "saved_model/")
    model.save(dir_path, overwrite=True)

    # also test subset var option
    adata2 = synthetic_iid(run_setup_anndata=False)
    adata2.obs["batch"] = adata2.obs.batch.cat.rename_categories(["batch_2", "batch_3"])

    model2 = PEAKVI.load_query_data(adata2, dir_path)
    model2.train(max_epochs=1, weight_decay=0.0, save_best=False)
    model2.get_latent_representation()

    # Encoder linear layer must stay equal for the original peak features...
    np.testing.assert_equal(
        encoder_input_weights(model, adata1.shape[1]),
        encoder_input_weights(model2, adata1.shape[1]),
    )
    # ...and those frozen columns must have received no gradient.
    frozen_grad = (
        model2.module.z_encoder.encoder.fc_layers[0][0]
        .weight.grad.cpu()
        .numpy()[:, : adata1.shape[1]]
    )
    assert np.sum(frozen_grad) == 0

    # Repeat with covariates encoded, then load query data with frozen expression.
    adata1 = synthetic_iid()
    model = PEAKVI(
        adata1,
        n_latent=n_latent,
        encode_covariates=True,
    )
    model.train(1, check_val_every_n_epoch=1, save_best=False)
    dir_path = os.path.join(save_path, "saved_model/")
    model.save(dir_path, overwrite=True)

    adata2 = synthetic_iid(run_setup_anndata=False)
    adata2.obs["batch"] = adata2.obs.batch.cat.rename_categories(["batch_2", "batch_3"])

    model2 = PEAKVI.load_query_data(adata2, dir_path, freeze_expression=True)
    model2.train(max_epochs=1, weight_decay=0.0, save_best=False)
    # deactivate no grad decorator
    model2.get_latent_representation()
    # pytorch lightning zeros the grad, so this will get a grad to inspect
    single_pass_for_online_update(model2)
    grad = model2.module.z_encoder.encoder.fc_layers[0][0].weight.grad.cpu().numpy()
    # expression part has zero grad
    assert np.sum(grad[:, :-4]) == 0
    # categorical part has non-zero grad
    assert np.count_nonzero(grad[:, -4:]) > 0

    # do not freeze expression
    model3 = PEAKVI.load_query_data(
        adata2,
        dir_path,
        freeze_expression=False,
        freeze_decoder_first_layer=False,
    )
    model3.train(max_epochs=1, save_best=False, weight_decay=0.0)
    model3.get_latent_representation()
    single_pass_for_online_update(model3)
    grad = model3.module.z_encoder.encoder.fc_layers[0][0].weight.grad.cpu().numpy()
    # linear layer weight in encoder layer has non-zero grad
    assert np.count_nonzero(grad[:, :-4]) != 0
    grad = model3.module.z_decoder.px_decoder.fc_layers[0][0].weight.grad.cpu().numpy()
    # linear layer weight in decoder layer has non-zero grad
    assert np.count_nonzero(grad[:, :-4]) != 0