Example #1
0
def test_saving_and_loading(save_path):
    """Check that models give the same outputs after a save/load round trip.

    SCVI, LinearSCVI and TOTALVI are compared on their latent representation
    and validation indices; AUTOZI on its alpha/beta posteriors; SCANVI on its
    predictions. For each model, loading with a 200-gene synthetic dataset is
    expected to raise ``ValueError``.
    """

    def check_roundtrip(model_cls, dataset, model_dir):
        # Train, record the reference outputs, then persist model + AnnData.
        trained = model_cls(dataset, latent_distribution="normal")
        trained.train(1, train_size=0.2)
        latent_before = trained.get_latent_representation(dataset)
        val_idx_before = trained.validation_indices
        trained.save(model_dir, overwrite=True, save_anndata=True)

        # Loading without an explicit AnnData must still give a usable model.
        model_cls.load(model_dir).get_latent_representation()

        # Loading with the 200-gene dataset must be rejected.
        mismatched = scvi.data.synthetic_iid(n_genes=200)
        with pytest.raises(ValueError):
            model_cls.load(model_dir, mismatched)

        restored = model_cls.load(model_dir, dataset)
        latent_after = restored.get_latent_representation()
        val_idx_after = restored.validation_indices
        np.testing.assert_array_equal(latent_before, latent_after)
        np.testing.assert_array_equal(val_idx_before, val_idx_after)
        assert restored.is_trained is True

    model_dir = os.path.join(save_path, "tmp")
    adata = synthetic_iid()

    for model_cls in (SCVI, LinearSCVI, TOTALVI):
        print(model_cls)
        check_roundtrip(model_cls, adata, model_dir)

    # --- AUTOZI: compare the alpha/beta posteriors across the round trip ---
    autozi = AUTOZI(adata, latent_distribution="normal")
    autozi.train(1, train_size=0.5)
    posteriors_before = autozi.get_alphas_betas()
    autozi.save(model_dir, overwrite=True, save_anndata=True)
    AUTOZI.load(model_dir).get_latent_representation()
    mismatched = scvi.data.synthetic_iid(n_genes=200)
    with pytest.raises(ValueError):
        AUTOZI.load(model_dir, mismatched)
    autozi = AUTOZI.load(model_dir, adata)
    posteriors_after = autozi.get_alphas_betas()
    np.testing.assert_array_equal(
        posteriors_before["alpha_posterior"], posteriors_after["alpha_posterior"]
    )
    np.testing.assert_array_equal(
        posteriors_before["beta_posterior"], posteriors_after["beta_posterior"]
    )
    assert autozi.is_trained is True

    # --- SCANVI: compare predictions across the round trip ---
    scanvi_model = SCANVI(adata, "label_0")
    scanvi_model.train(max_epochs=1, train_size=0.5)
    predictions_before = scanvi_model.predict()
    scanvi_model.save(model_dir, overwrite=True, save_anndata=True)
    SCANVI.load(model_dir).get_latent_representation()
    mismatched = scvi.data.synthetic_iid(n_genes=200)
    with pytest.raises(ValueError):
        SCANVI.load(model_dir, mismatched)
    scanvi_model = SCANVI.load(model_dir, adata)
    predictions_after = scanvi_model.predict()
    np.testing.assert_array_equal(predictions_before, predictions_after)
    assert scanvi_model.is_trained is True
Example #2
0
def test_saving_and_loading(save_path):
    """Check that several models give identical outputs after save + load.

    Each model is trained briefly, saved under ``<save_path>/tmp``, reloaded
    with the AnnData it was trained on, and its outputs are compared with the
    ones recorded before saving. Here ``load`` takes the AnnData first and the
    directory second.
    """
    def test_save_load_model(cls, adata, save_path):
        # Generic round trip: latent representation and test indices must match.
        model = cls(adata, latent_distribution="normal")
        model.train(1)
        z1 = model.get_latent_representation(adata)
        test_idx1 = model.test_indices
        model.save(save_path, overwrite=True)
        model = cls.load(adata, save_path)
        z2 = model.get_latent_representation()
        test_idx2 = model.test_indices
        np.testing.assert_array_equal(z1, z2)
        np.testing.assert_array_equal(test_idx1, test_idx2)
        assert model.is_trained is True

    # All models below write to (and overwrite) the same directory.
    save_path = os.path.join(save_path, "tmp")
    adata = synthetic_iid()

    for cls in [SCVI, LinearSCVI, TOTALVI]:
        print(cls)
        test_save_load_model(cls, adata, save_path)

    # AUTOZI
    # Compared on the alpha/beta posterior arrays instead of the latent space.
    model = AUTOZI(adata, latent_distribution="normal")
    model.train(1)
    ab1 = model.get_alphas_betas()
    model.save(save_path, overwrite=True)
    model = AUTOZI.load(adata, save_path)
    ab2 = model.get_alphas_betas()
    np.testing.assert_array_equal(ab1["alpha_posterior"],
                                  ab2["alpha_posterior"])
    np.testing.assert_array_equal(ab1["beta_posterior"], ab2["beta_posterior"])
    assert model.is_trained is True

    # SCANVI
    # Compared on the predictions returned by predict().
    model = SCANVI(adata, "undefined_0")
    model.train(n_epochs_unsupervised=1, n_epochs_semisupervised=1)
    p1 = model.predict()
    model.save(save_path, overwrite=True)
    model = SCANVI.load(adata, save_path)
    p2 = model.predict()
    np.testing.assert_array_equal(p1, p2)
    assert model.is_trained is True

    # GIMVI
    # Built from the same AnnData twice; latent representations are requested
    # for a one-element list of datasets.
    model = GIMVI(adata, adata)
    model.train(1)
    z1 = model.get_latent_representation([adata])
    z2 = model.get_latent_representation([adata])
    # Two consecutive calls on the same model must agree before saving.
    np.testing.assert_array_equal(z1, z2)
    model.save(save_path, overwrite=True)
    model = GIMVI.load(adata, adata, save_path)
    z2 = model.get_latent_representation([adata])
    np.testing.assert_array_equal(z1, z2)
    assert model.is_trained is True
Example #3
0
def test_multiple_covariates_scvi(save_path):
    """Smoke-test SCVI, SCANVI and TOTALVI with extra covariates registered.

    Two continuous and two categorical random columns are added to ``obs`` and
    passed to ``setup_anndata``; each model only has to train without error.
    """
    continuous_keys = ("cont1", "cont2")
    categorical_keys = ("cat1", "cat2")

    adata = synthetic_iid()
    n_cells = adata.shape[0]
    # Draw order (cont1, cont2, cat1, cat2) matches the global RNG consumption
    # of assigning the four columns one after another.
    for column in continuous_keys:
        adata.obs[column] = np.random.normal(size=(n_cells,))
    for column in categorical_keys:
        adata.obs[column] = np.random.randint(0, 5, size=(n_cells,))

    SCVI.setup_anndata(
        adata,
        batch_key="batch",
        labels_key="labels",
        continuous_covariate_keys=list(continuous_keys),
        categorical_covariate_keys=list(categorical_keys),
    )
    SCVI(adata).train(1)

    # SCANVI is built on the AnnData as registered by SCVI.setup_anndata above.
    SCANVI(adata, unlabeled_category="Unknown").train(1)

    TOTALVI.setup_anndata(
        adata,
        batch_key="batch",
        protein_expression_obsm_key="protein_expression",
        protein_names_uns_key="protein_names",
        continuous_covariate_keys=list(continuous_keys),
        categorical_covariate_keys=list(categorical_keys),
    )
    TOTALVI(adata).train(1)
Example #4
0
def test_scanvi(save_path):
    """Exercise SCANVI training, logging, prediction and ``from_scvi_model``.

    Checks that the expected metric names appear in ``model.history``, that
    ``predict`` honours ``indices``, and that the remaining calls run without
    raising.
    """
    expected_history_keys = (
        "elbo_validation",
        "reconstruction_loss_validation",
        "kl_local_validation",
        "elbo_train",
        "reconstruction_loss_train",
        "kl_local_train",
        "classification_loss_validation",
    )

    train_adata = synthetic_iid()
    model = SCANVI(train_adata, "label_0", n_latent=10)
    model.train(1, train_size=0.5, check_val_every_n_epoch=1)
    logged_keys = model.history.keys()
    for metric in expected_history_keys:
        assert metric in logged_keys

    other_adata = synthetic_iid()
    subset_predictions = model.predict(other_adata, indices=[1, 2, 3])
    assert len(subset_predictions) == 3
    model.predict()
    model.predict(other_adata, soft=True)
    model.predict(other_adata, soft=True, indices=[1, 2, 3])
    model.get_normalized_expression(other_adata)
    model.differential_expression(groupby="labels", group1="label_1")
    model.differential_expression(groupby="labels", group1="label_1", group2="label_2")

    # First case: "asdf" as the unlabeled category (all data labeled).
    # Second case: "label_0" as the unlabeled category (labeled/unlabeled mix).
    labelling_cases = (
        ("asdf", {}),
        ("label_0", {"train_size": 0.9}),
    )
    for unlabeled_category, train_kwargs in labelling_cases:
        case_adata = scvi.data.synthetic_iid()
        scvi.data.setup_anndata(case_adata, batch_key="batch", labels_key="labels")
        case_model = scvi.model.SCANVI(case_adata, unlabeled_category)
        case_model.train(1, **train_kwargs)

    # test from_scvi_model
    base_adata = scvi.data.synthetic_iid()
    base_model = scvi.model.SCVI(base_adata, use_observed_lib_size=False)
    new_adata = scvi.data.synthetic_iid()
    # Built with a different AnnData; only construction is exercised here.
    scvi.model.SCANVI.from_scvi_model(base_model, "label_0", adata=new_adata)
    derived_model = scvi.model.SCANVI.from_scvi_model(
        base_model, "label_0", use_labels_groups=False
    )
    derived_model.train(1)
Example #5
0
def test_scanvi():
    """Train SCANVI briefly and exercise its history and inference methods.

    The history entries are checked for their expected lengths; the
    prediction, normalized-expression and differential-expression calls only
    have to run (the index-restricted prediction must return three items).
    """
    train_adata = synthetic_iid()
    scanvi_model = SCANVI(train_adata, "label_0", n_latent=10)
    scanvi_model.train(1, train_size=0.5, frequency=1)

    expected_history_lengths = (
        ("unsupervised_trainer_history", 2),
        ("semisupervised_trainer_history", 3),
    )
    for history_key, expected_len in expected_history_lengths:
        assert len(scanvi_model.history[history_key]) == expected_len

    other_adata = synthetic_iid()
    subset_predictions = scanvi_model.predict(other_adata, indices=[1, 2, 3])
    assert len(subset_predictions) == 3
    scanvi_model.predict()
    scanvi_model.predict(other_adata, soft=True)
    scanvi_model.predict(other_adata, soft=True, indices=[1, 2, 3])
    scanvi_model.get_normalized_expression(other_adata)

    # One-vs-rest first, then an explicit two-group comparison.
    scanvi_model.differential_expression(groupby="labels", group1="label_1")
    scanvi_model.differential_expression(
        groupby="labels", group1="label_1", group2="label_2"
    )
Example #6
0
def test_scanvi_online_update(save_path):
    """Exercise ``SCANVI.load_query_data`` against saved reference models.

    Scenarios covered, in order:

    1. Reference with one "Unknown" label, fully unlabeled query.
    2. Reference with its labels unmodified, query with one "Unknown" label;
       with the default ``freeze_classifier`` the classifier weights must stay
       equal to the reference's, with ``freeze_classifier=False`` they must
       change.
    3. A query model created by ``load_query_data`` is saved and loaded again.

    Parameters
    ----------
    save_path
        Directory (pytest fixture) under which models are written.
    """

    def first_classifier_weight(scanvi_model):
        # Weight stored at classifier.classifier[0].fc_layers[0][0], detached
        # and converted to a NumPy array on the CPU.
        return (
            scanvi_model.module.classifier.classifier[0]
            .fc_layers[0][0]
            .weight.detach()
            .cpu()
            .numpy()
        )

    n_latent = 5
    dir_path = os.path.join(save_path, "saved_model/")

    # ref has semi-observed labels
    adata1 = synthetic_iid(run_setup_anndata=False)
    new_labels = adata1.obs.labels.to_numpy()
    new_labels[0] = "Unknown"
    adata1.obs["labels"] = pd.Categorical(new_labels)
    setup_anndata(adata1, batch_key="batch", labels_key="labels")
    model = SCANVI(adata1, "Unknown", n_latent=n_latent, encode_covariates=True)
    model.train(max_epochs=1, check_val_every_n_epoch=1)
    model.save(dir_path, overwrite=True)

    adata2 = synthetic_iid(run_setup_anndata=False)
    adata2.obs["batch"] = adata2.obs.batch.cat.rename_categories(["batch_2", "batch_3"])
    adata2.obs["labels"] = "Unknown"

    model = SCANVI.load_query_data(adata2, dir_path, freeze_batchnorm_encoder=True)
    model.train(max_epochs=1)
    model.get_latent_representation()
    model.predict()

    # ref has fully-observed labels
    adata1 = synthetic_iid(run_setup_anndata=False)
    new_labels = adata1.obs.labels.to_numpy()
    adata1.obs["labels"] = pd.Categorical(new_labels)
    setup_anndata(adata1, batch_key="batch", labels_key="labels")
    model = SCANVI(adata1, "Unknown", n_latent=n_latent, encode_covariates=True)
    model.train(max_epochs=1, check_val_every_n_epoch=1)
    model.save(dir_path, overwrite=True)

    # query has one new label
    adata2 = synthetic_iid(run_setup_anndata=False)
    adata2.obs["batch"] = adata2.obs.batch.cat.rename_categories(["batch_2", "batch_3"])
    new_labels = adata2.obs.labels.to_numpy()
    new_labels[0] = "Unknown"
    adata2.obs["labels"] = pd.Categorical(new_labels)

    model2 = SCANVI.load_query_data(adata2, dir_path, freeze_batchnorm_encoder=True)
    # Force every query cell to be treated as unlabeled.
    model2._unlabeled_indices = np.arange(adata2.n_obs)
    model2._labeled_indices = []
    # weight_decay=0.0 so that decay alone cannot move the frozen weights.
    model2.train(max_epochs=1, plan_kwargs=dict(weight_decay=0.0))
    model2.get_latent_representation()
    model2.predict()

    # test classifier frozen
    class_query_weight = first_classifier_weight(model2)
    class_ref_weight = first_classifier_weight(model)
    # weight decay makes difference
    np.testing.assert_allclose(class_query_weight, class_ref_weight, atol=1e-07)

    # test classifier unfrozen
    model2 = SCANVI.load_query_data(adata2, dir_path, freeze_classifier=False)
    model2._unlabeled_indices = np.arange(adata2.n_obs)
    model2._labeled_indices = []
    model2.train(max_epochs=1)
    class_query_weight = first_classifier_weight(model2)
    class_ref_weight = first_classifier_weight(model)
    # The weights are now expected to differ, so the closeness check must fail.
    with pytest.raises(AssertionError):
        np.testing.assert_allclose(class_query_weight, class_ref_weight, atol=1e-07)

    # test saving and loading of online scanvi
    a = scvi.data.synthetic_iid(run_setup_anndata=False)
    ref = a[a.obs["labels"] != "label_2"].copy()  # only has labels 0 and 1
    scvi.data.setup_anndata(ref, batch_key="batch", labels_key="labels")
    m = SCANVI(ref, "label_2")
    m.train(max_epochs=1)
    m.save(save_path, overwrite=True)
    # NOTE(review): a subset of `a` without "label_0" used to be built here and
    # was immediately overwritten by the line below. That dead assignment is
    # removed; the query that is used is a fresh synthetic dataset. Confirm
    # whether the subset was the intended query.
    query = scvi.data.synthetic_iid()
    m_q = SCANVI.load_query_data(query, save_path)
    m_q.save(save_path, overwrite=True)
    m_q = SCANVI.load(save_path, query)
    m_q.predict()
    m_q.get_elbo()
Example #7
0
def test_scanvi(save_path):
    """Exercise SCANVI training and check how cells are split into index sets.

    Beyond the smoke tests (history keys, prediction, normalized expression,
    differential expression), this verifies that train/validation/test indices
    cover every observation in the expected proportions and that the
    unlabeled-to-labeled ratio in the training split is close to the ratio in
    the whole dataset.

    Parameters
    ----------
    save_path
        Pytest fixture; not used in the body.
    """
    adata = synthetic_iid()
    model = SCANVI(adata, "label_0", n_latent=10)
    model.train(1, train_size=0.5, check_val_every_n_epoch=1)
    logged_keys = model.history.keys()
    assert "elbo_validation" in logged_keys
    assert "reconstruction_loss_validation" in logged_keys
    assert "kl_local_validation" in logged_keys
    assert "elbo_train" in logged_keys
    assert "reconstruction_loss_train" in logged_keys
    assert "kl_local_train" in logged_keys
    adata2 = synthetic_iid()
    predictions = model.predict(adata2, indices=[1, 2, 3])
    assert len(predictions) == 3
    model.predict()
    model.predict(adata2, soft=True)
    model.predict(adata2, soft=True, indices=[1, 2, 3])
    model.get_normalized_expression(adata2)
    model.differential_expression(groupby="labels", group1="label_1")
    model.differential_expression(groupby="labels", group1="label_1", group2="label_2")

    # test that all data labeled runs
    unknown_label = "asdf"
    a = scvi.data.synthetic_iid()
    scvi.data.setup_anndata(a, batch_key="batch", labels_key="labels")
    m = scvi.model.SCANVI(a, unknown_label)
    m.train(1)

    # check the number of indices
    n_train_idx = len(m.train_indices)
    n_validation_idx = len(m.validation_indices)
    n_test_idx = len(m.test_indices)
    assert n_train_idx + n_validation_idx + n_test_idx == a.n_obs
    assert np.isclose(n_train_idx / a.n_obs, 0.9)
    assert np.isclose(n_validation_idx / a.n_obs, 0.1)
    assert np.isclose(n_test_idx / a.n_obs, 0)

    # test mix of labeled and unlabeled data
    unknown_label = "label_0"
    a = scvi.data.synthetic_iid()
    scvi.data.setup_anndata(a, batch_key="batch", labels_key="labels")
    m = scvi.model.SCANVI(a, unknown_label)
    m.train(1, train_size=0.9)
    # check the number of indices
    n_train_idx = len(m.train_indices)
    n_validation_idx = len(m.validation_indices)
    n_test_idx = len(m.test_indices)
    assert n_train_idx + n_validation_idx + n_test_idx == a.n_obs
    assert np.isclose(n_train_idx / a.n_obs, 0.9, rtol=0.05)
    assert np.isclose(n_validation_idx / a.n_obs, 0.1, rtol=0.05)
    # rtol scales with the expected value, so with an expected value of 0 this
    # only passes when the test split is (numerically) empty.
    assert np.isclose(n_test_idx / a.n_obs, 0, rtol=0.05)

    # check that training indices have proper mix of labeled and unlabeled data
    labelled_idx = np.where(a.obs["labels"] != unknown_label)[0]
    unlabelled_idx = np.where(a.obs["labels"] == unknown_label)[0]
    # Count the membership with one vectorised np.isin call per group instead
    # of scanning the index array once per training index.
    # number of labeled training idx
    n_labeled_train = int(np.count_nonzero(np.isin(m.train_indices, labelled_idx)))
    # number of unlabeled training idx
    n_unlabeled_train = int(
        np.count_nonzero(np.isin(m.train_indices, unlabelled_idx))
    )
    n_labeled_idx = len(m._labeled_indices)
    n_unlabeled_idx = len(m._unlabeled_indices)
    # unlabeled-to-labeled ratio in adata
    adata_ratio = n_unlabeled_idx / n_labeled_idx
    # unlabeled-to-labeled ratio in train set
    train_ratio = n_unlabeled_train / n_labeled_train
    assert np.isclose(adata_ratio, train_ratio, atol=0.05)
Example #8
0
def test_scanvi():
    """Train SCANVI with "undefined_0" as unlabeled category and run inference.

    Only the index-restricted prediction is asserted on (three indices in,
    three predictions out); the other calls just have to complete.
    """
    train_adata = synthetic_iid()
    scanvi_model = SCANVI(train_adata, "undefined_0", n_latent=10)
    scanvi_model.train(1)

    other_adata = synthetic_iid()
    subset_predictions = scanvi_model.predict(other_adata, indices=[1, 2, 3])
    assert len(subset_predictions) == 3

    scanvi_model.predict()
    scanvi_model.predict(other_adata, soft=True)
    scanvi_model.predict(other_adata, soft=True, indices=[1, 2, 3])
    scanvi_model.get_normalized_expression(other_adata)

    # One-group call first, then an explicit two-group comparison.
    scanvi_model.differential_expression(groupby="labels", group1="undefined_1")
    scanvi_model.differential_expression(
        groupby="labels", group1="undefined_1", group2="undefined_2"
    )
Example #9
0
    def test_save_load_scanvi(legacy=False):
        """Round-trip a SCANVI model through save/load and compare predictions.

        Uses ``adata``, ``save_path`` and ``legacy_save`` from the enclosing
        scope. With ``legacy=True`` the model is written with ``legacy_save``
        instead of ``model.save``.
        """
        prefix = "SCANVI_"
        scanvi_model = SCANVI(adata, "label_0")
        scanvi_model.train(max_epochs=1, train_size=0.5)
        predictions_before = scanvi_model.predict()

        # Both writers receive exactly the same options.
        save_options = dict(overwrite=True, save_anndata=True, prefix=prefix)
        if not legacy:
            scanvi_model.save(save_path, **save_options)
        else:
            legacy_save(scanvi_model, save_path, **save_options)

        # Loading without an explicit AnnData must give a usable model.
        SCANVI.load(save_path, prefix=prefix).get_latent_representation()

        # Loading with the 200-gene dataset must be rejected.
        mismatched_adata = scvi.data.synthetic_iid(n_genes=200)
        with pytest.raises(ValueError):
            SCANVI.load(save_path, adata=mismatched_adata, prefix=prefix)

        reloaded = SCANVI.load(save_path, adata=adata, prefix=prefix)
        registry = adata.uns["_scvi"]["data_registry"]
        assert "test" in registry
        assert registry["test"] == dict(attr_name="obs", attr_key="cont1")

        predictions_after = reloaded.predict()
        np.testing.assert_array_equal(predictions_before, predictions_after)
        assert reloaded.is_trained is True
Example #10
0
def test_scanvi_online_update(save_path):
    """Smoke-test ``SCANVI.load_query_data`` for two reference/query setups.

    First a reference with one "Unknown" label and a fully unlabeled query,
    then a reference with unmodified labels and a query with one "Unknown"
    label. In both cases the query model only has to train, embed and predict
    without raising.
    """
    n_latent = 5
    dir_path = os.path.join(save_path, "saved_model/")
    query_batch_names = ["batch_2", "batch_3"]
    epoch_kwargs = dict(n_epochs_unsupervised=1, n_epochs_semisupervised=1)

    # ---- reference with semi-observed labels ----
    ref_adata = synthetic_iid(run_setup_anndata=False)
    ref_labels = ref_adata.obs["labels"].to_numpy()
    ref_labels[0] = "Unknown"
    ref_adata.obs["labels"] = pd.Categorical(ref_labels)
    setup_anndata(ref_adata, batch_key="batch", labels_key="labels")
    ref_model = SCANVI(ref_adata, "Unknown", n_latent=n_latent, encode_covariates=True)
    ref_model.train(**epoch_kwargs, frequency=1)
    ref_model.save(dir_path, overwrite=True)

    # Query: renamed batch categories, every cell marked "Unknown".
    query_adata = synthetic_iid(run_setup_anndata=False)
    query_adata.obs["batch"] = query_adata.obs["batch"].cat.rename_categories(
        list(query_batch_names)
    )
    query_adata.obs["labels"] = "Unknown"

    query_model = SCANVI.load_query_data(
        query_adata, dir_path, freeze_batchnorm_encoder=True
    )
    query_model.train(**epoch_kwargs, train_base_model=False)
    query_model.get_latent_representation()
    query_model.predict()

    # ---- reference with fully-observed labels ----
    ref_adata = synthetic_iid(run_setup_anndata=False)
    ref_labels = ref_adata.obs["labels"].to_numpy()
    ref_adata.obs["labels"] = pd.Categorical(ref_labels)
    setup_anndata(ref_adata, batch_key="batch", labels_key="labels")
    ref_model = SCANVI(ref_adata, "Unknown", n_latent=n_latent, encode_covariates=True)
    ref_model.train(**epoch_kwargs, frequency=1)
    ref_model.save(dir_path, overwrite=True)

    # Query with one new label.
    query_adata = synthetic_iid(run_setup_anndata=False)
    query_adata.obs["batch"] = query_adata.obs["batch"].cat.rename_categories(
        list(query_batch_names)
    )
    query_labels = query_adata.obs["labels"].to_numpy()
    query_labels[0] = "Unknown"
    query_adata.obs["labels"] = pd.Categorical(query_labels)

    query_model = SCANVI.load_query_data(
        query_adata, dir_path, freeze_batchnorm_encoder=True
    )
    # Force every query cell to be treated as unlabeled.
    query_model._unlabeled_indices = np.arange(query_adata.n_obs)
    query_model._labeled_indices = []
    query_model.train(**epoch_kwargs, train_base_model=False)
    query_model.get_latent_representation()
    query_model.predict()
Example #11
0
def test_scanvi(save_path):
    """Exercise SCANVI when the unlabeled category is given to ``setup_anndata``.

    Covers training with validation logging, the prediction /
    normalized-expression / differential-expression methods, two labelling
    scenarios, and ``SCANVI.from_scvi_model`` (with a new AnnData, without
    label groups, and with a registered size factor). ``save_path`` is a
    pytest fixture that the body does not use.
    """
    adata = synthetic_iid()
    SCANVI.setup_anndata(
        adata,
        "label_0",
        batch_key="batch",
        labels_key="labels",
    )
    model = SCANVI(adata, n_latent=10)
    model.train(1, train_size=0.5, check_val_every_n_epoch=1)
    # Both training and validation metrics must have been logged.
    logged_keys = model.history.keys()
    assert "elbo_validation" in logged_keys
    assert "reconstruction_loss_validation" in logged_keys
    assert "kl_local_validation" in logged_keys
    assert "elbo_train" in logged_keys
    assert "reconstruction_loss_train" in logged_keys
    assert "kl_local_train" in logged_keys
    assert "classification_loss_validation" in logged_keys
    adata2 = synthetic_iid()
    # Restricting to three indices must yield three predictions.
    predictions = model.predict(adata2, indices=[1, 2, 3])
    assert len(predictions) == 3
    model.predict()
    # With soft=True the result is expected to be a DataFrame.
    df = model.predict(adata2, soft=True)
    assert isinstance(df, pd.DataFrame)
    model.predict(adata2, soft=True, indices=[1, 2, 3])
    model.get_normalized_expression(adata2)
    model.differential_expression(groupby="labels", group1="label_1")
    model.differential_expression(groupby="labels",
                                  group1="label_1",
                                  group2="label_2")

    # test that all data labeled runs
    unknown_label = "asdf"
    a = scvi.data.synthetic_iid()
    scvi.model.SCANVI.setup_anndata(a,
                                    unknown_label,
                                    batch_key="batch",
                                    labels_key="labels")
    m = scvi.model.SCANVI(a)
    m.train(1)

    # test mix of labeled and unlabeled data
    unknown_label = "label_0"
    a = scvi.data.synthetic_iid()
    scvi.model.SCANVI.setup_anndata(a,
                                    unknown_label,
                                    batch_key="batch",
                                    labels_key="labels")
    m = scvi.model.SCANVI(a)
    m.train(1, train_size=0.9)

    # test from_scvi_model
    a = scvi.data.synthetic_iid()
    SCVI.setup_anndata(
        a,
        batch_key="batch",
        labels_key="labels",
    )
    m = SCVI(a, use_observed_lib_size=False)
    a2 = scvi.data.synthetic_iid()
    scanvi_model = scvi.model.SCANVI.from_scvi_model(m, "label_0", adata=a2)
    # make sure the state_dicts are different objects for the two models
    assert scanvi_model.module.state_dict() is not m.module.state_dict()
    # Both models must expose a "px_r" entry, and it must not be one shared object.
    scanvi_pxr = scanvi_model.module.state_dict().get("px_r", None)
    scvi_pxr = m.module.state_dict().get("px_r", None)
    assert scanvi_pxr is not None and scvi_pxr is not None
    assert scanvi_pxr is not scvi_pxr
    scanvi_model.train(1)

    # Test without label groups
    scanvi_model = scvi.model.SCANVI.from_scvi_model(m,
                                                     "label_0",
                                                     use_labels_groups=False)
    scanvi_model.train(1)

    # test from_scvi_model with size_factor
    a = scvi.data.synthetic_iid()
    # Random integer size factors in [1, 5).
    a.obs["size_factor"] = np.random.randint(1, 5, size=(a.shape[0], ))
    SCVI.setup_anndata(a,
                       batch_key="batch",
                       labels_key="labels",
                       size_factor_key="size_factor")
    m = SCVI(a, use_observed_lib_size=False)
    a2 = scvi.data.synthetic_iid()
    # The new AnnData gets the same obs column that was registered for the SCVI model.
    a2.obs["size_factor"] = np.random.randint(1, 5, size=(a2.shape[0], ))
    scanvi_model = scvi.model.SCANVI.from_scvi_model(m, "label_0", adata=a2)
    scanvi_model.train(1)
Example #12
0
def test_saving_and_loading(save_path):
    """Check save/load round trips for both the legacy and current save layouts.

    For each model class the sequence is: legacy save + load (must pass),
    current save + load (must pass), then legacy save + load again, which is
    expected to fail with ``AssertionError`` because files from the current
    save are still present in the directory (see the inline comments).
    """
    def legacy_save(
        model,
        dir_path,
        prefix=None,
        overwrite=False,
        save_anndata=False,
        **anndata_write_kwargs,
    ):
        """Write ``model`` to ``dir_path`` as separate files.

        Files written (each name preceded by ``prefix`` if given):
        ``model_params.pt`` (module state dict), ``attr.pkl`` (pickled user
        attributes whose name ends with ``_``), ``var_names.csv`` and, when
        ``save_anndata`` is true, ``adata.h5ad``.

        Raises ``ValueError`` if ``dir_path`` exists and ``overwrite`` is false.
        """
        if not os.path.exists(dir_path) or overwrite:
            os.makedirs(dir_path, exist_ok=overwrite)
        else:
            raise ValueError(
                "{} already exists. Please provide an unexisting directory for saving."
                .format(dir_path))

        file_name_prefix = prefix or ""

        if save_anndata:
            model.adata.write(
                os.path.join(dir_path, f"{file_name_prefix}adata.h5ad"),
                **anndata_write_kwargs,
            )

        model_save_path = os.path.join(dir_path,
                                       f"{file_name_prefix}model_params.pt")
        attr_save_path = os.path.join(dir_path, f"{file_name_prefix}attr.pkl")
        varnames_save_path = os.path.join(dir_path,
                                          f"{file_name_prefix}var_names.csv")

        torch.save(model.module.state_dict(), model_save_path)

        # One variable name per line, written as plain strings.
        var_names = model.adata.var_names.astype(str)
        var_names = var_names.to_numpy()
        np.savetxt(varnames_save_path, var_names, fmt="%s")

        # get all the user attributes
        user_attributes = model._get_user_attributes()
        # only save the public attributes with _ at the very end
        user_attributes = {
            a[0]: a[1]
            for a in user_attributes if a[0][-1] == "_"
        }

        with open(attr_save_path, "wb") as f:
            pickle.dump(user_attributes, f)

    def test_save_load_model(cls, adata, save_path, prefix=None, legacy=False):
        """Set up, train, save and reload ``cls``; compare outputs before/after.

        ``legacy`` selects ``legacy_save`` instead of ``model.save``. Loading
        with mismatched genes or renamed batch categories must raise
        ``ValueError``.
        """
        # TOTALVI is additionally given the protein expression keys.
        if cls is TOTALVI:
            cls.setup_anndata(
                adata,
                batch_key="batch",
                labels_key="labels",
                protein_expression_obsm_key="protein_expression",
                protein_names_uns_key="protein_names",
            )
        else:
            cls.setup_anndata(adata, batch_key="batch", labels_key="labels")
        model = cls(adata, latent_distribution="normal")
        model.train(1, train_size=0.2)
        z1 = model.get_latent_representation(adata)
        test_idx1 = model.validation_indices
        if legacy:
            legacy_save(model,
                        save_path,
                        overwrite=True,
                        save_anndata=True,
                        prefix=prefix)
        else:
            model.save(save_path,
                       overwrite=True,
                       save_anndata=True,
                       prefix=prefix)
        # Loading without an explicit AnnData must give a usable model.
        model = cls.load(save_path, prefix=prefix)
        model.get_latent_representation()

        # Load with mismatched genes.
        tmp_adata = synthetic_iid(n_genes=200, )
        with pytest.raises(ValueError):
            cls.load(save_path, adata=tmp_adata, prefix=prefix)

        # Load with different batches.
        tmp_adata = synthetic_iid()
        tmp_adata.obs["batch"] = tmp_adata.obs["batch"].cat.rename_categories(
            ["batch_2", "batch_3"])
        with pytest.raises(ValueError):
            cls.load(save_path, adata=tmp_adata, prefix=prefix)

        model = cls.load(save_path, adata=adata, prefix=prefix)
        # The reloaded model's registry must map "batch" to obs["_scvi_batch"].
        assert "batch" in model.adata_manager.data_registry
        assert model.adata_manager.data_registry["batch"] == dict(
            attr_name="obs", attr_key="_scvi_batch")

        z2 = model.get_latent_representation()
        test_idx2 = model.validation_indices
        np.testing.assert_array_equal(z1, z2)
        np.testing.assert_array_equal(test_idx1, test_idx2)
        assert model.is_trained is True

    # Every model writes into the same directory; files are told apart by prefix.
    save_path = os.path.join(save_path, "tmp")
    adata = synthetic_iid()

    for cls in [SCVI, LinearSCVI, TOTALVI, PEAKVI]:
        test_save_load_model(cls,
                             adata,
                             save_path,
                             prefix=f"{cls.__name__}_",
                             legacy=True)
        test_save_load_model(cls, adata, save_path, prefix=f"{cls.__name__}_")
        # Test load prioritizes newer save paradigm and thus mismatches legacy save.
        with pytest.raises(AssertionError):
            test_save_load_model(cls,
                                 adata,
                                 save_path,
                                 prefix=f"{cls.__name__}_",
                                 legacy=True)

    # AUTOZI
    def test_save_load_autozi(legacy=False):
        """Round-trip AUTOZI and compare alpha/beta posteriors before and after.

        Reads ``adata`` and ``save_path`` from the enclosing scope.
        """
        prefix = "AUTOZI_"
        model = AUTOZI(adata, latent_distribution="normal")
        model.train(1, train_size=0.5)
        ab1 = model.get_alphas_betas()
        if legacy:
            legacy_save(model,
                        save_path,
                        overwrite=True,
                        save_anndata=True,
                        prefix=prefix)
        else:
            model.save(save_path,
                       overwrite=True,
                       save_anndata=True,
                       prefix=prefix)
        model = AUTOZI.load(save_path, prefix=prefix)
        model.get_latent_representation()
        # Loading with the 200-gene dataset must be rejected.
        tmp_adata = scvi.data.synthetic_iid(n_genes=200)
        with pytest.raises(ValueError):
            AUTOZI.load(save_path, adata=tmp_adata, prefix=prefix)
        model = AUTOZI.load(save_path, adata=adata, prefix=prefix)
        assert "batch" in model.adata_manager.data_registry
        assert model.adata_manager.data_registry["batch"] == dict(
            attr_name="obs", attr_key="_scvi_batch")

        ab2 = model.get_alphas_betas()
        np.testing.assert_array_equal(ab1["alpha_posterior"],
                                      ab2["alpha_posterior"])
        np.testing.assert_array_equal(ab1["beta_posterior"],
                                      ab2["beta_posterior"])
        assert model.is_trained is True

    # Setup happens once here; the helper above only builds and trains the model.
    AUTOZI.setup_anndata(adata, batch_key="batch", labels_key="labels")
    test_save_load_autozi(legacy=True)
    test_save_load_autozi()
    # Test load prioritizes newer save paradigm and thus mismatches legacy save.
    with pytest.raises(AssertionError):
        test_save_load_autozi(legacy=True)

    # SCANVI
    def test_save_load_scanvi(legacy=False):
        """Round-trip SCANVI and compare predictions before and after.

        Reads ``adata`` and ``save_path`` from the enclosing scope.
        """
        prefix = "SCANVI_"
        model = SCANVI(adata)
        model.train(max_epochs=1, train_size=0.5)
        p1 = model.predict()
        if legacy:
            legacy_save(model,
                        save_path,
                        overwrite=True,
                        save_anndata=True,
                        prefix=prefix)
        else:
            model.save(save_path,
                       overwrite=True,
                       save_anndata=True,
                       prefix=prefix)
        model = SCANVI.load(save_path, prefix=prefix)
        model.get_latent_representation()
        # Loading with the 200-gene dataset must be rejected.
        tmp_adata = scvi.data.synthetic_iid(n_genes=200)
        with pytest.raises(ValueError):
            SCANVI.load(save_path, adata=tmp_adata, prefix=prefix)
        model = SCANVI.load(save_path, adata=adata, prefix=prefix)
        assert "batch" in model.adata_manager.data_registry
        assert model.adata_manager.data_registry["batch"] == dict(
            attr_name="obs", attr_key="_scvi_batch")

        p2 = model.predict()
        np.testing.assert_array_equal(p1, p2)
        assert model.is_trained is True

    # "label_0" is passed to setup_anndata here rather than to the constructor.
    SCANVI.setup_anndata(adata,
                         "label_0",
                         batch_key="batch",
                         labels_key="labels")
    test_save_load_scanvi(legacy=True)
    test_save_load_scanvi()
    # Test load prioritizes newer save paradigm and thus mismatches legacy save.
    with pytest.raises(AssertionError):
        test_save_load_scanvi(legacy=True)