Exemplo n.º 1
0
def test_autozi():
    """Smoke-test AUTOZI under both dispersion / zero-inflation settings.

    For each of ``"gene"`` and ``"gene-label"`` a model is built on a
    single-batch synthetic dataset, trained via ``train(1, lr=1e-2)``, and
    its evaluation getters are invoked on the test indices. Nothing is
    asserted about the returned values; the test only checks that every
    call completes without raising.
    """
    adata = synthetic_iid(n_batches=1)

    for setting in ("gene", "gene-label"):
        # The same setting drives both dispersion and zero-inflation.
        model = AUTOZI(adata, dispersion=setting, zero_inflation=setting)
        model.train(1, lr=1e-2)

        # Each of these getters is evaluated on the held-out test indices.
        indexed_getters = (
            model.get_elbo,
            model.get_reconstruction_error,
            model.get_marginal_ll,
        )
        for getter in indexed_getters:
            getter(indices=model.test_indices)

        model.get_alphas_betas()
Exemplo n.º 2
0
def test_autozi():
    """Smoke-test AUTOZI training, history tracking and evaluation getters.

    For each of ``"gene"`` and ``"gene-label"`` (used for both dispersion
    and zero-inflation) a model is built on a single-batch synthetic
    dataset and trained with ``check_val_every_n_epoch=1``. The test then
    asserts that the train and validation ELBO histories each hold exactly
    one entry, and calls the evaluation getters on the validation indices.
    """
    adata = synthetic_iid(n_batches=1)
    for setting in ("gene", "gene-label"):
        model = AUTOZI(adata, dispersion=setting, zero_inflation=setting)
        # A fresh plan_kwargs dict is built for every model.
        model.train(1, plan_kwargs={"lr": 1e-2}, check_val_every_n_epoch=1)

        # Exactly one value is expected per tracked ELBO metric.
        for metric in ("elbo_train", "elbo_validation"):
            assert len(model.history[metric]) == 1

        model.get_elbo(indices=model.validation_indices)
        model.get_reconstruction_error(indices=model.validation_indices)
        model.get_marginal_ll(
            indices=model.validation_indices,
            n_mc_samples=3,
        )
        model.get_alphas_betas()
Exemplo n.º 3
0
def test_autozi():
    """Smoke-test AUTOZI training with per-epoch metric recording.

    For each of ``"gene"`` and ``"gene-label"`` (used for both dispersion
    and zero-inflation) a model is built on a single-batch synthetic
    dataset and trained via ``train(1, lr=1e-2, frequency=1)``. The test
    asserts that the train-set and test-set ELBO histories each hold two
    entries, then calls the evaluation getters on the test indices.
    """
    adata = synthetic_iid(n_batches=1)

    for setting in ("gene", "gene-label"):
        model = AUTOZI(adata, dispersion=setting, zero_inflation=setting)
        model.train(1, lr=1e-2, frequency=1)

        # Two recorded values are expected for each ELBO history.
        for metric in ("elbo_train_set", "elbo_test_set"):
            assert len(model.history[metric]) == 2

        # Each of these getters is evaluated on the held-out test indices.
        indexed_getters = (
            model.get_elbo,
            model.get_reconstruction_error,
            model.get_marginal_ll,
        )
        for getter in indexed_getters:
            getter(indices=model.test_indices)

        model.get_alphas_betas()
Exemplo n.º 4
0
def test_autozi():
    """Smoke-test AUTOZI with and without a modelled library size.

    The synthetic dataset is created without automatic setup and then
    registered through ``AUTOZI.setup_anndata`` with explicit batch and
    label keys. Two passes follow, each covering the ``"gene"`` and
    ``"gene-label"`` settings (used for both dispersion and
    zero-inflation):

    1. models built with the default library-size arguments;
    2. models built with ``use_observed_lib_size=False``, for which the
       module must expose ``library_log_means`` and ``library_log_vars``.

    In both passes the train and validation ELBO histories must each hold
    exactly one entry, and the evaluation getters are called on the
    validation indices.
    """
    adata = synthetic_iid(n_batches=1, run_setup_anndata=False)
    AUTOZI.setup_anndata(adata, batch_key="batch", labels_key="labels")

    # Pass order matters: default construction first, modelled library
    # size second.
    for model_library_size in (False, True):
        for setting in ("gene", "gene-label"):
            # Only the second pass adds the extra constructor argument.
            extra_kwargs = (
                {"use_observed_lib_size": False} if model_library_size else {}
            )
            model = AUTOZI(
                adata,
                dispersion=setting,
                zero_inflation=setting,
                **extra_kwargs,
            )
            model.train(
                1,
                plan_kwargs={"lr": 1e-2},
                check_val_every_n_epoch=1,
            )

            if model_library_size:
                # Model library size.
                assert all(
                    hasattr(model.module, attribute)
                    for attribute in ("library_log_means", "library_log_vars")
                )

            # Exactly one value is expected per tracked ELBO metric.
            for metric in ("elbo_train", "elbo_validation"):
                assert len(model.history[metric]) == 1

            model.get_elbo(indices=model.validation_indices)
            model.get_reconstruction_error(indices=model.validation_indices)
            model.get_marginal_ll(
                indices=model.validation_indices,
                n_mc_samples=3,
            )
            model.get_alphas_betas()