Esempio n. 1
0
def test_peakvi():
    """Smoke-test PEAKVI on synthetic data.

    Trains the model briefly with ``model_depth`` and then
    ``region_factors`` disabled, trains a default model for three epochs,
    and finally calls each inference/analysis method once. The test only
    checks that these calls run; no return values are asserted.
    """
    adata = synthetic_iid(run_setup_anndata=False)
    PEAKVI.setup_anndata(adata, batch_key="batch")

    # One short training run per optional component switched off.
    for disabled_option in ("model_depth", "region_factors"):
        vae = PEAKVI(adata, **{disabled_option: False})
        vae.train(1, save_best=False)

    # Default configuration; this is the model used for the calls below.
    vae = PEAKVI(adata)
    vae.train(3)

    vae.get_elbo(indices=vae.validation_indices)

    # Accessibility estimates: raw, then with each normalisation flag.
    vae.get_accessibility_estimates()
    for normalisation in ("normalize_cells", "normalize_regions"):
        vae.get_accessibility_estimates(**{normalisation: True})

    vae.get_library_size_factors()
    vae.get_region_factors()
    vae.get_reconstruction_error(indices=vae.validation_indices)
    vae.get_latent_representation()
    vae.differential_accessibility(groupby="labels", group1="label_1")
Esempio n. 2
0
def test_peakvi_online_update(save_path):
    """Check PEAKVI online updates via ``PEAKVI.load_query_data``.

    Three scenarios are exercised, each against a reference model that is
    trained for one epoch and saved under ``save_path``:

    1. Default query loading: the leading columns of the first encoder
       layer's weight must equal the reference model's, and their gradient
       must sum to zero.
    2. ``encode_covariates=True`` reference with ``freeze_expression=True``
       query: all but the last 4 input columns have zero-sum gradient,
       while the last 4 have some non-zero gradient.
    3. Nothing frozen: both the first encoder layer and the first decoder
       layer have non-zero gradient outside the last 4 columns.

    Parameters
    ----------
    save_path
        Directory under which the reference model is saved (joined with
        ``"saved_model/"``).
    """

    def first_encoder_weight(fitted):
        # Weight of the first linear layer of the z-encoder.
        return fitted.module.z_encoder.encoder.fc_layers[0][0].weight

    def as_numpy(tensor):
        return tensor.cpu().numpy()

    def make_query_adata():
        # Fresh synthetic data whose batch categories are renamed so they
        # differ from the reference data's categories.
        query = synthetic_iid()
        renamed = query.obs.batch.cat.rename_categories(
            ["batch_2", "batch_3"])
        query.obs["batch"] = renamed
        return query

    latent_dim = 5

    # --- Scenario 1: default online update -------------------------------
    reference_adata = synthetic_iid()
    PEAKVI.setup_anndata(
        reference_adata, batch_key="batch", labels_key="labels")
    reference = PEAKVI(reference_adata, n_latent=latent_dim)
    reference.train(1, save_best=False)
    model_dir = os.path.join(save_path, "saved_model/")
    reference.save(model_dir, overwrite=True)

    query_adata = make_query_adata()
    query_model = PEAKVI.load_query_data(query_adata, model_dir)
    query_model.train(max_epochs=1, weight_decay=0.0, save_best=False)
    query_model.get_latent_representation()

    # The encoder's first linear layer must be unchanged for the feature
    # columns (the first ``n_features`` input columns).
    n_features = reference_adata.shape[1]
    reference_weight = as_numpy(
        first_encoder_weight(reference).detach())[:, :n_features]
    query_weight = as_numpy(
        first_encoder_weight(query_model).detach())[:, :n_features]
    np.testing.assert_equal(reference_weight, query_weight)

    feature_grad = as_numpy(
        first_encoder_weight(query_model).grad)[:, :n_features]
    assert np.sum(feature_grad) == 0

    # --- Scenario 2: covariates encoded, expression frozen ---------------
    reference_adata = synthetic_iid()
    PEAKVI.setup_anndata(
        reference_adata, batch_key="batch", labels_key="labels")
    reference = PEAKVI(
        reference_adata, n_latent=latent_dim, encode_covariates=True)
    reference.train(1, check_val_every_n_epoch=1, save_best=False)
    # Same location as before; the earlier save is overwritten.
    reference.save(model_dir, overwrite=True)

    query_adata = make_query_adata()
    frozen_model = PEAKVI.load_query_data(
        query_adata, model_dir, freeze_expression=True)
    frozen_model.train(max_epochs=1, weight_decay=0.0, save_best=False)
    frozen_model.get_latent_representation()
    # PyTorch Lightning zeros the gradients after training, so run one
    # extra pass to obtain gradients that can be inspected.
    single_pass_for_online_update(frozen_model)
    encoder_grad = as_numpy(first_encoder_weight(frozen_model).grad)
    # All but the trailing 4 input columns ("expression part" in the
    # original test) must have zero-sum gradient ...
    assert np.sum(encoder_grad[:, :-4]) == 0
    # ... while the trailing 4 columns ("categorical part" in the original
    # test; presumably the one-hot batch covariates - confirm) are updated.
    assert np.count_nonzero(encoder_grad[:, -4:]) > 0

    # --- Scenario 3: nothing frozen --------------------------------------
    unfrozen_model = PEAKVI.load_query_data(
        query_adata,
        model_dir,
        freeze_expression=False,
        freeze_decoder_first_layer=False,
    )
    unfrozen_model.train(max_epochs=1, save_best=False, weight_decay=0.0)
    unfrozen_model.get_latent_representation()
    single_pass_for_online_update(unfrozen_model)

    # First encoder layer now receives gradient outside the last 4 columns.
    encoder_grad = as_numpy(first_encoder_weight(unfrozen_model).grad)
    assert np.count_nonzero(encoder_grad[:, :-4]) != 0

    # The first decoder layer does as well.
    decoder_weight = (
        unfrozen_model.module.z_decoder.px_decoder.fc_layers[0][0].weight)
    decoder_grad = as_numpy(decoder_weight.grad)
    assert np.count_nonzero(decoder_grad[:, :-4]) != 0