예제 #1
0
def test_autozi():
    """Smoke-test AUTOZI for two dispersion / zero-inflation settings.

    The synthetic dataset is registered once with ``AUTOZI.setup_anndata``.
    A model is then built for each of the ``"gene"`` and ``"gene-label"``
    settings, first with the constructor defaults and afterwards with
    ``use_observed_lib_size=False``. Every model is trained with
    ``train(1, ...)``, must report exactly one ``elbo_train`` and one
    ``elbo_validation`` history entry, and has its evaluation helpers called
    on the validation indices.
    """
    adata = synthetic_iid(n_batches=1, run_setup_anndata=False)
    AUTOZI.setup_anndata(adata, batch_key="batch", labels_key="labels")

    # Pass 1 relies on the constructor defaults; pass 2 models library size.
    for extra_kwargs in ({}, {"use_observed_lib_size": False}):
        models_library_size = bool(extra_kwargs)
        for mode in ("gene", "gene-label"):
            model = AUTOZI(
                adata,
                dispersion=mode,
                zero_inflation=mode,
                **extra_kwargs,
            )
            model.train(
                1,
                plan_kwargs={"lr": 1e-2},
                check_val_every_n_epoch=1,
            )

            if models_library_size:
                # Only checked when the observed library size is not used.
                assert hasattr(model.module, "library_log_means")
                assert hasattr(model.module, "library_log_vars")

            assert len(model.history["elbo_train"]) == 1
            assert len(model.history["elbo_validation"]) == 1

            # These calls only need to run without raising.
            model.get_elbo(indices=model.validation_indices)
            model.get_reconstruction_error(indices=model.validation_indices)
            model.get_marginal_ll(
                indices=model.validation_indices,
                n_mc_samples=3,
            )
            model.get_alphas_betas()
예제 #2
0
def test_saving_and_loading(save_path):
    """Round-trip several models through both save formats and reload them.

    For SCVI, LinearSCVI, TOTALVI, PEAKVI, AUTOZI and SCANVI the same
    sequence is exercised: a legacy-format save followed by a load, a save
    through ``model.save`` followed by a load, and finally a second
    legacy-format save into the same directory, which is expected to end in
    an ``AssertionError``.

    Parameters
    ----------
    save_path
        Base directory; all files are written to its ``"tmp"`` subdirectory.
    """

    def legacy_save(
        model,
        dir_path,
        prefix=None,
        overwrite=False,
        save_anndata=False,
        **anndata_write_kwargs,
    ):
        """Write ``model`` to ``dir_path`` as separate legacy-format files.

        Produces ``model_params.pt`` (module state dict), ``var_names.csv``
        and ``attr.pkl`` (pickled attributes whose name ends with ``_``),
        plus ``adata.h5ad`` when ``save_anndata`` is truthy. Each file name
        is prepended with ``prefix``. ``anndata_write_kwargs`` are forwarded
        to ``model.adata.write``.

        Raises
        ------
        ValueError
            If ``dir_path`` already exists and ``overwrite`` is falsy.
        """
        if os.path.exists(dir_path) and not overwrite:
            raise ValueError(
                "{} already exists. Please provide an unexisting directory for saving."
                .format(dir_path))
        os.makedirs(dir_path, exist_ok=overwrite)

        stem = prefix or ""

        def target(file_name):
            # Location of one prefixed output file inside dir_path.
            return os.path.join(dir_path, f"{stem}{file_name}")

        if save_anndata:
            model.adata.write(target("adata.h5ad"), **anndata_write_kwargs)

        torch.save(model.module.state_dict(), target("model_params.pt"))

        gene_names = model.adata.var_names.astype(str).to_numpy()
        np.savetxt(target("var_names.csv"), gene_names, fmt="%s")

        # Keep only the attributes whose name ends with an underscore.
        kept_attributes = {}
        for entry in model._get_user_attributes():
            if entry[0][-1] == "_":
                kept_attributes[entry[0]] = entry[1]

        with open(target("attr.pkl"), "wb") as handle:
            pickle.dump(kept_attributes, handle)

    def persist(model, dir_path, prefix, legacy):
        """Save ``model`` with ``legacy_save`` or with ``model.save``."""
        if legacy:
            legacy_save(
                model,
                dir_path,
                overwrite=True,
                save_anndata=True,
                prefix=prefix,
            )
            return
        model.save(
            dir_path,
            overwrite=True,
            save_anndata=True,
            prefix=prefix,
        )

    def assert_batch_registered(model):
        """Check the ``"batch"`` entry of the model's data registry."""
        registry = model.adata_manager.data_registry
        assert "batch" in registry
        assert registry["batch"] == {
            "attr_name": "obs",
            "attr_key": "_scvi_batch",
        }

    def check_model_roundtrip(cls, adata, dir_path, prefix=None, legacy=False):
        """Train ``cls``, save it, reload it and compare latent outputs."""
        setup_kwargs = {"batch_key": "batch", "labels_key": "labels"}
        if cls is TOTALVI:
            setup_kwargs["protein_expression_obsm_key"] = "protein_expression"
            setup_kwargs["protein_names_uns_key"] = "protein_names"
        cls.setup_anndata(adata, **setup_kwargs)

        model = cls(adata, latent_distribution="normal")
        model.train(1, train_size=0.2)
        latent_before = model.get_latent_representation(adata)
        val_idx_before = model.validation_indices

        persist(model, dir_path, prefix, legacy)

        # Loading without an explicit adata must give a usable model.
        model = cls.load(dir_path, prefix=prefix)
        model.get_latent_representation()

        # Load with mismatched genes.
        other_adata = synthetic_iid(n_genes=200)
        with pytest.raises(ValueError):
            cls.load(dir_path, adata=other_adata, prefix=prefix)

        # Load with different batches.
        other_adata = synthetic_iid()
        renamed_batches = other_adata.obs["batch"].cat.rename_categories(
            ["batch_2", "batch_3"])
        other_adata.obs["batch"] = renamed_batches
        with pytest.raises(ValueError):
            cls.load(dir_path, adata=other_adata, prefix=prefix)

        model = cls.load(dir_path, adata=adata, prefix=prefix)
        assert_batch_registered(model)

        latent_after = model.get_latent_representation()
        val_idx_after = model.validation_indices
        np.testing.assert_array_equal(latent_before, latent_after)
        np.testing.assert_array_equal(val_idx_before, val_idx_after)
        assert model.is_trained is True

    model_dir = os.path.join(save_path, "tmp")
    adata = synthetic_iid()

    for cls in [SCVI, LinearSCVI, TOTALVI, PEAKVI]:
        cls_prefix = f"{cls.__name__}_"
        check_model_roundtrip(cls, adata, model_dir, prefix=cls_prefix,
                              legacy=True)
        check_model_roundtrip(cls, adata, model_dir, prefix=cls_prefix)
        # Test load prioritizes newer save paradigm and thus mismatches legacy save.
        with pytest.raises(AssertionError):
            check_model_roundtrip(cls, adata, model_dir, prefix=cls_prefix,
                                  legacy=True)

    # AUTOZI
    def check_autozi_roundtrip(legacy=False):
        """Train AUTOZI, save it, reload it and compare alpha/beta outputs."""
        prefix = "AUTOZI_"
        model = AUTOZI(adata, latent_distribution="normal")
        model.train(1, train_size=0.5)
        params_before = model.get_alphas_betas()

        persist(model, model_dir, prefix, legacy)

        model = AUTOZI.load(model_dir, prefix=prefix)
        model.get_latent_representation()

        # An adata generated with n_genes=200 is expected to be rejected.
        other_adata = scvi.data.synthetic_iid(n_genes=200)
        with pytest.raises(ValueError):
            AUTOZI.load(model_dir, adata=other_adata, prefix=prefix)

        model = AUTOZI.load(model_dir, adata=adata, prefix=prefix)
        assert_batch_registered(model)

        params_after = model.get_alphas_betas()
        for key in ("alpha_posterior", "beta_posterior"):
            np.testing.assert_array_equal(params_before[key],
                                          params_after[key])
        assert model.is_trained is True

    AUTOZI.setup_anndata(adata, batch_key="batch", labels_key="labels")
    check_autozi_roundtrip(legacy=True)
    check_autozi_roundtrip()
    # Test load prioritizes newer save paradigm and thus mismatches legacy save.
    with pytest.raises(AssertionError):
        check_autozi_roundtrip(legacy=True)

    # SCANVI
    def check_scanvi_roundtrip(legacy=False):
        """Train SCANVI, save it, reload it and compare predictions."""
        prefix = "SCANVI_"
        model = SCANVI(adata)
        model.train(max_epochs=1, train_size=0.5)
        predictions_before = model.predict()

        persist(model, model_dir, prefix, legacy)

        model = SCANVI.load(model_dir, prefix=prefix)
        model.get_latent_representation()

        # An adata generated with n_genes=200 is expected to be rejected.
        other_adata = scvi.data.synthetic_iid(n_genes=200)
        with pytest.raises(ValueError):
            SCANVI.load(model_dir, adata=other_adata, prefix=prefix)

        model = SCANVI.load(model_dir, adata=adata, prefix=prefix)
        assert_batch_registered(model)

        predictions_after = model.predict()
        np.testing.assert_array_equal(predictions_before, predictions_after)
        assert model.is_trained is True

    SCANVI.setup_anndata(
        adata,
        "label_0",
        batch_key="batch",
        labels_key="labels",
    )
    check_scanvi_roundtrip(legacy=True)
    check_scanvi_roundtrip()
    # Test load prioritizes newer save paradigm and thus mismatches legacy save.
    with pytest.raises(AssertionError):
        check_scanvi_roundtrip(legacy=True)