Example #1
def test_extra_covariates_transfer():
    adata = synthetic_iid()
    adata.obs["cont1"] = np.random.normal(size=(adata.shape[0],))
    adata.obs["cont2"] = np.random.normal(size=(adata.shape[0],))
    adata.obs["cat1"] = np.random.randint(0, 5, size=(adata.shape[0],))
    adata.obs["cat2"] = np.random.randint(0, 5, size=(adata.shape[0],))
    setup_anndata(
        adata,
        batch_key="batch",
        labels_key="labels",
        protein_expression_obsm_key="protein_expression",
        protein_names_uns_key="protein_names",
        continuous_covariate_keys=["cont1", "cont2"],
        categorical_covariate_keys=["cat1", "cat2"],
    )
    bdata = synthetic_iid()
    bdata.obs["cont1"] = np.random.normal(size=(bdata.shape[0],))
    bdata.obs["cont2"] = np.random.normal(size=(bdata.shape[0],))
    bdata.obs["cat1"] = 0
    bdata.obs["cat2"] = 1

    transfer_anndata_setup(adata_source=adata, adata_target=bdata)

    # give it a new category
    del bdata.uns["_scvi"]
    bdata.obs["cat1"] = 6
    transfer_anndata_setup(
        adata_source=adata, adata_target=bdata, extend_categories=True
    )
    assert bdata.uns["_scvi"]["extra_categoricals"]["mappings"]["cat1"][-1] == 6
Example #2
    def _validate_anndata(self,
                          adata: Optional[AnnData] = None,
                          copy_if_view: bool = True):
        """Validate anndata has been properly registered, transfer if necessary."""
        if adata is None:
            adata = self.adata
        if adata.is_view:
            if copy_if_view:
                logger.info("Received view of anndata, making copy.")
                adata = adata.copy()
            else:
                raise ValueError("Please run `adata = adata.copy()`")

        if "_scvi" not in adata.uns_keys():
            logger.info("Input adata not setup with scvi. " +
                        "attempting to transfer anndata setup")
            transfer_anndata_setup(self.scvi_setup_dict_, adata)
        is_nonneg_int = _check_nonnegative_integers(
            get_from_registry(adata, _CONSTANTS.X_KEY))
        if not is_nonneg_int:
            logger.warning(
                "Make sure the registered X field in anndata contains unnormalized count data."
            )

        _check_anndata_setup_equivalence(self.scvi_setup_dict_, adata)

        return adata
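
A hedged usage sketch of what `_validate_anndata` enables: model methods can accept an AnnData that was never run through `setup_anndata`, because validation transfers the training setup onto it first. It reuses the `synthetic_iid` test helper seen throughout these examples; treat it as an illustration, not scvi-tools' documented workflow.

import scvi
from scvi.data import synthetic_iid

adata = synthetic_iid()  # the helper runs setup_anndata for us
model = scvi.model.SCVI(adata)
model.train(1)

query = synthetic_iid(run_setup_anndata=False)
# get_elbo() calls _validate_anndata(query) internally, which logs the
# transfer attempt and copies the setup registry onto `query`.
model.get_elbo(query)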
Example #3
    def load(
        cls,
        dir_path: str,
        adata: Optional[AnnData] = None,
        use_gpu: Optional[bool] = None,
    ):
        """
        Instantiate a model from the saved output.

        Parameters
        ----------
        dir_path
            Path to saved outputs.
        adata
            AnnData organized in the same way as data used to train model.
            It is not necessary to run :func:`~scvi.data.setup_anndata`,
            as AnnData is validated against the saved `scvi` setup dictionary.
            If None, will check for and load anndata saved with the model.
        use_gpu
            Whether to load model on GPU.

        Returns
        -------
        Model with loaded state dictionaries.

        Examples
        --------
        >>> vae = SCVI.load(save_path, adata)
        >>> vae.get_latent_representation()
        """
        load_adata = adata is None
        if use_gpu is None:
            use_gpu = torch.cuda.is_available()
        map_location = torch.device("cpu") if use_gpu is False else None
        (
            scvi_setup_dict,
            attr_dict,
            var_names,
            model_state_dict,
            new_adata,
        ) = _load_saved_files(dir_path, load_adata, map_location=map_location)
        adata = new_adata if new_adata is not None else adata

        _validate_var_names(adata, var_names)
        transfer_anndata_setup(scvi_setup_dict, adata)
        model = _initialize_model(cls, adata, attr_dict, use_gpu)

        # set saved attrs for loaded model
        for attr, val in attr_dict.items():
            setattr(model, attr, val)

        model.module.load_state_dict(model_state_dict)
        if use_gpu:
            model.module.cuda()

        model.module.eval()
        model._validate_anndata(adata)

        return model
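
A minimal save/load round trip for the classmethod above, assuming the companion `save()` method with its `save_anndata` flag; the directory name is illustrative.

from scvi.data import synthetic_iid
from scvi.model import SCVI

adata = synthetic_iid()
model = SCVI(adata)
model.train(1)
model.save("saved_model/", save_anndata=True)  # path is illustrative

loaded = SCVI.load("saved_model/")  # adata=None -> loads the saved adata
latent = loaded.get_latent_representation()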
Example #4
def _transfer_model(model, adata):
    adata = adata.copy()

    attr_dict = model._get_user_attributes()
    attr_dict = {a[0]: a[1] for a in attr_dict if a[0][-1] == "_"}
    scvi_setup_dict = attr_dict.pop("scvi_setup_dict_")
    transfer_anndata_setup(scvi_setup_dict, adata, extend_categories=True)

    adata.uns["_scvi"]["summary_stats"]["n_labels"] = scvi_setup_dict[
        "summary_stats"
    ]["n_labels"]

    new_model = _initialize_model(model.__class__, adata, attr_dict, use_cuda=True)

    for attr, val in attr_dict.items():
        setattr(new_model, attr, val)

    model.model.cuda()
    new_model.model.cuda()

    load_state_dict = model.model.state_dict().copy()
    new_state_dict = new_model.model.state_dict()
    for key, load_ten in load_state_dict.items():
        new_ten = new_state_dict[key]
        if new_ten.size() == load_ten.size():
            continue
        else:
            dim_diff = new_ten.size()[-1] - load_ten.size()[-1]
            fixed_ten = torch.cat([load_ten, new_ten[..., -dim_diff:]], dim=-1)
            load_state_dict[key] = fixed_ten
    new_model.model.load_state_dict(load_state_dict)
    new_model.model.eval()

    new_model.is_trained_ = False

    return new_model, adata
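
The state-dict loop above pads reference tensors whose trailing dimension grew after `transfer_anndata_setup(..., extend_categories=True)`. A minimal torch sketch of that padding step, with hypothetical sizes:

import torch

load_ten = torch.ones(4, 6)  # reference weights: 6 input columns
new_ten = torch.randn(4, 8)  # query model grew to 8 inputs (2 new categories)
dim_diff = new_ten.size()[-1] - load_ten.size()[-1]
# keep the trained columns, append the freshly initialized ones
fixed_ten = torch.cat([load_ten, new_ten[..., -dim_diff:]], dim=-1)
assert fixed_ten.shape == (4, 8)
assert torch.equal(fixed_ten[..., :6], load_ten)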
Example #5
    def load(
        cls,
        dir_path: str,
        adata: Optional[AnnData] = None,
        use_cuda: bool = False,
    ):
        """
        Instantiate a model from the saved output.

        Parameters
        ----------
        dir_path
            Path to saved outputs.
        adata
            AnnData organized in the same way as data used to train model.
            It is not necessary to run :func:`~scvi.data.setup_anndata`,
            as AnnData is validated against the saved `scvi` setup dictionary.
            If None, will check for and load anndata saved with the model.
        use_cuda
            Whether to load model on GPU.

        Returns
        -------
        Model with loaded state dictionaries.

        Examples
        --------
        >>> vae = SCVI.load(save_path, adata)
        >>> vae.get_latent_representation()
        """
        model_path = os.path.join(dir_path, "model_params.pt")
        setup_dict_path = os.path.join(dir_path, "attr.pkl")
        adata_path = os.path.join(dir_path, "adata.h5ad")
        varnames_path = os.path.join(dir_path, "var_names.csv")

        if os.path.exists(adata_path) and adata is None:
            adata = read(adata_path)
        elif not os.path.exists(adata_path) and adata is None:
            raise ValueError(
                "Save path contains no saved anndata and no adata was passed.")
        var_names = np.genfromtxt(varnames_path, delimiter=",", dtype=str)
        user_var_names = adata.var_names.astype(str)
        if not np.array_equal(var_names, user_var_names):
            logger.warning(
                "var_names for adata passed in does not match var_names of "
                "adata used to train the model. For valid results, the vars "
                "need to be the same and in the same order as the adata used to train the model."
            )

        with open(setup_dict_path, "rb") as handle:
            attr_dict = pickle.load(handle)

        scvi_setup_dict = attr_dict.pop("scvi_setup_dict_")

        transfer_anndata_setup(scvi_setup_dict, adata)

        if "init_params_" not in attr_dict.keys():
            raise ValueError(
                "No init_params_ were saved by the model. Check out the "
                "developers guide if creating custom models.")
        # get the parameters for the class init signature
        init_params = attr_dict.pop("init_params_")

        # update use_cuda from the saved model
        use_cuda = use_cuda and torch.cuda.is_available()
        init_params["use_cuda"] = use_cuda

        # grab all the parameters except for kwargs (which is a dict)
        non_kwargs = {
            k: v
            for k, v in init_params.items() if not isinstance(v, dict)
        }
        # expand out kwargs
        kwargs = {k: v for k, v in init_params.items() if isinstance(v, dict)}
        kwargs = {k: v for (i, j) in kwargs.items() for (k, v) in j.items()}
        model = cls(adata, **non_kwargs, **kwargs)
        for attr, val in attr_dict.items():
            setattr(model, attr, val)

        if use_cuda:
            model.model.load_state_dict(torch.load(model_path))
            model.model.cuda()
        else:
            device = torch.device("cpu")
            model.model.load_state_dict(
                torch.load(model_path, map_location=device))

        model.model.eval()
        model._validate_anndata(adata)

        return model
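
The nested-kwargs flattening above is terse; a standalone illustration of what the comprehensions do (parameter names are hypothetical):

init_params = {"n_latent": 10, "model_kwargs": {"dropout_rate": 0.1}}
non_kwargs = {k: v for k, v in init_params.items() if not isinstance(v, dict)}
kwargs = {k: v for k, v in init_params.items() if isinstance(v, dict)}
kwargs = {k: v for (i, j) in kwargs.items() for (k, v) in j.items()}
assert non_kwargs == {"n_latent": 10}
assert kwargs == {"dropout_rate": 0.1}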
Example #6
def test_scvi():
    n_latent = 5
    adata = synthetic_iid()
    model = SCVI(adata, n_latent=n_latent)
    model.train(1, frequency=1, train_size=0.5)
    assert model.is_trained is True
    z = model.get_latent_representation()
    assert z.shape == (adata.shape[0], n_latent)
    # history length should be 2, since metrics are also run once at the very end of training
    assert len(model.history["elbo_train_set"]) == 2
    model.get_elbo()
    model.get_marginal_ll()
    model.get_reconstruction_error()
    model.get_normalized_expression(transform_batch="batch_1")

    adata2 = synthetic_iid()
    model.get_elbo(adata2)
    model.get_marginal_ll(adata2)
    model.get_reconstruction_error(adata2)
    latent = model.get_latent_representation(adata2, indices=[1, 2, 3])
    assert latent.shape == (3, n_latent)
    denoised = model.get_normalized_expression(adata2)
    assert denoised.shape == adata.shape

    denoised = model.get_normalized_expression(adata2,
                                               indices=[1, 2, 3],
                                               transform_batch="batch_1")
    denoised = model.get_normalized_expression(
        adata2, indices=[1, 2, 3], transform_batch=["batch_0", "batch_1"])
    assert denoised.shape == (3, adata2.n_vars)
    sample = model.posterior_predictive_sample(adata2)
    assert sample.shape == adata2.shape
    sample = model.posterior_predictive_sample(adata2,
                                               indices=[1, 2, 3],
                                               gene_list=["1", "2"])
    assert sample.shape == (3, 2)
    sample = model.posterior_predictive_sample(adata2,
                                               indices=[1, 2, 3],
                                               gene_list=["1", "2"],
                                               n_samples=3)
    assert sample.shape == (3, 2, 3)

    model.get_feature_correlation_matrix(correlation_type="pearson")
    model.get_feature_correlation_matrix(
        adata2,
        indices=[1, 2, 3],
        correlation_type="spearman",
        rna_size_factor=500,
        n_samples=5,
    )
    model.get_feature_correlation_matrix(
        adata2,
        indices=[1, 2, 3],
        correlation_type="spearman",
        rna_size_factor=500,
        n_samples=5,
        transform_batch=["batch_0", "batch_1"],
    )
    params = model.get_likelihood_parameters()
    assert params["mean"].shape == adata.shape
    assert (params["mean"].shape == params["dispersions"].shape ==
            params["dropout"].shape)
    params = model.get_likelihood_parameters(adata2, indices=[1, 2, 3])
    assert params["mean"].shape == (3, adata.n_vars)
    params = model.get_likelihood_parameters(adata2,
                                             indices=[1, 2, 3],
                                             n_samples=3,
                                             give_mean=True)
    assert params["mean"].shape == (3, adata.n_vars)
    model.get_latent_library_size()
    model.get_latent_library_size(adata2, indices=[1, 2, 3])

    # test transfer_anndata_setup
    adata2 = synthetic_iid(run_setup_anndata=False)
    transfer_anndata_setup(adata, adata2)
    model.get_elbo(adata2)

    # test automatic transfer_anndata_setup + on a view
    adata = synthetic_iid()
    model = SCVI(adata)
    adata2 = synthetic_iid(run_setup_anndata=False)
    model.get_elbo(adata2[:10])

    # test that we catch incorrect mappings
    adata = synthetic_iid()
    adata2 = synthetic_iid(run_setup_anndata=False)
    transfer_anndata_setup(adata, adata2)
    adata2.uns["_scvi"]["categorical_mappings"]["_scvi_labels"][
        "mapping"] = np.array(["label_1", "label_0", "label_2"])
    with pytest.raises(ValueError):
        model.get_elbo(adata2)

    # test mismatched categories raises ValueError
    adata2 = synthetic_iid(run_setup_anndata=False)
    adata2.obs.labels.cat.rename_categories(["a", "b", "c"], inplace=True)
    with pytest.raises(ValueError):
        model.get_elbo(adata2)

    # test differential expression
    model.differential_expression(groupby="labels", group1="label_1")
    model.differential_expression(groupby="labels",
                                  group1="label_1",
                                  group2="label_2",
                                  mode="change")
    model.differential_expression(groupby="labels")
    model.differential_expression(idx1=[0, 1, 2], idx2=[3, 4, 5])
    model.differential_expression(idx1=[0, 1, 2])

    # transform batch works with all different types
    a = synthetic_iid(run_setup_anndata=False)
    batch = np.zeros(a.n_obs)
    batch[:64] += 1
    a.obs["batch"] = batch
    setup_anndata(a, batch_key="batch")
    m = SCVI(a)
    m.train(1, train_size=0.5)
    m.get_normalized_expression(transform_batch=1)
    m.get_normalized_expression(transform_batch=[0, 1])
Example #7
def test_transfer_anndata_setup():
    # test transfer_anndata function
    adata1 = synthetic_iid(run_setup_anndata=False)
    adata2 = synthetic_iid(run_setup_anndata=False)
    adata2.X = adata1.X
    setup_anndata(adata1)
    transfer_anndata_setup(adata1, adata2)
    np.testing.assert_array_equal(adata1.obs["_scvi_local_l_mean"],
                                  adata2.obs["_scvi_local_l_mean"])

    # test that a layer used in the initial setup is used again in the transferred setup
    adata1 = synthetic_iid(run_setup_anndata=False)
    adata2 = synthetic_iid(run_setup_anndata=False)
    raw_counts = adata1.X.copy()
    adata1.layers["raw"] = raw_counts
    adata2.layers["raw"] = raw_counts
    zeros = np.zeros_like(adata1.X)
    ones = np.ones_like(adata1.X)
    adata1.X = zeros
    adata2.X = ones
    setup_anndata(adata1, layer="raw")
    transfer_anndata_setup(adata1, adata2)
    np.testing.assert_array_equal(adata1.obs["_scvi_local_l_mean"],
                                  adata2.obs["_scvi_local_l_mean"])

    # test that an unknown batch throws an error
    adata1 = synthetic_iid()
    adata2 = synthetic_iid(run_setup_anndata=False)
    adata2.obs["batch"] = [2] * adata2.n_obs
    with pytest.raises(ValueError):
        transfer_anndata_setup(adata1, adata2)

    # TODO: test that a batch with wrong dtype throws an error
    # adata1 = synthetic_iid()
    # adata2 = synthetic_iid(run_setup_anndata=False)
    # adata2.obs["batch"] = ["0"] * adata2.n_obs
    # with pytest.raises(ValueError):
    #     transfer_anndata_setup(adata1, adata2)

    # test that an unknown label throws an error
    adata1 = synthetic_iid()
    adata2 = synthetic_iid(run_setup_anndata=False)
    adata2.obs["labels"] = ["label_123"] * adata2.n_obs
    with pytest.raises(ValueError):
        transfer_anndata_setup(adata1, adata2)

    # test that correct mapping was applied
    adata1 = synthetic_iid()
    adata2 = synthetic_iid(run_setup_anndata=False)
    adata2.obs["labels"] = ["label_1"] * adata2.n_obs
    transfer_anndata_setup(adata1, adata2)
    labels_mapping = adata1.uns["_scvi"]["categorical_mappings"][
        "_scvi_labels"]["mapping"]
    correct_label = np.where(labels_mapping == "label_1")[0][0]
    adata2.obs["_scvi_labels"][0] == correct_label

    # test that transfer_anndata_setup correctly looks for adata.obs['batch']
    adata1 = synthetic_iid()
    adata2 = synthetic_iid(run_setup_anndata=False)
    del adata2.obs["batch"]
    with pytest.raises(KeyError):
        transfer_anndata_setup(adata1, adata2)

    # test that transfer_anndata_setup assigns same batch and label to cells
    # if the original anndata was also same batch and label
    adata1 = synthetic_iid(run_setup_anndata=False)
    setup_anndata(adata1)
    adata2 = synthetic_iid(run_setup_anndata=False)
    del adata2.obs["batch"]
    transfer_anndata_setup(adata1, adata2)
    assert adata2.obs["_scvi_batch"][0] == 0
    assert adata2.obs["_scvi_labels"][0] == 0
Example #8
    def load(
        cls,
        dir_path: str,
        adata: Optional[AnnData] = None,
        use_gpu: Optional[Union[str, int, bool]] = None,
    ):
        """
        Instantiate a model from the saved output.

        Parameters
        ----------
        dir_path
            Path to saved outputs.
        adata
            AnnData organized in the same way as data used to train model.
            It is not necessary to run :func:`~scvi.data.setup_anndata`,
            as AnnData is validated against the saved `scvi` setup dictionary.
            If None, will check for and load anndata saved with the model.
        use_gpu
            Load model on default GPU if available (if None or True),
            or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False).

        Returns
        -------
        Model with loaded state dictionaries.

        Examples
        --------
        >>> vae = SCVI.load(save_path, adata)
        >>> vae.get_latent_representation()
        """
        load_adata = adata is None
        use_gpu, device = parse_use_gpu_arg(use_gpu)

        (
            scvi_setup_dict,
            attr_dict,
            var_names,
            model_state_dict,
            new_adata,
        ) = _load_saved_files(dir_path, load_adata, map_location=device)
        adata = new_adata if new_adata is not None else adata

        _validate_var_names(adata, var_names)
        transfer_anndata_setup(scvi_setup_dict, adata)
        model = _initialize_model(cls, adata, attr_dict)

        # set saved attrs for loaded model
        for attr, val in attr_dict.items():
            setattr(model, attr, val)

        # some Pyro modules with AutoGuides may need one training step
        try:
            model.module.load_state_dict(model_state_dict)
        except RuntimeError as err:
            if isinstance(model.module, PyroBaseModuleClass):
                logger.info("Preparing underlying module for load")
                model.train(max_steps=1)
                pyro.clear_param_store()
                model.module.load_state_dict(model_state_dict)
            else:
                raise err

        model.to_device(device)
        model.module.eval()
        model._validate_anndata(adata)
        return model
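
The try/except around `load_state_dict` exists because Pyro AutoGuide parameters are created lazily on the first training step, so a freshly constructed module may not yet contain the saved keys. A plain-torch analogy of the failure mode (not scvi-tools API):

import torch
import torch.nn as nn

saved = {"fc.weight": torch.zeros(2, 2), "fc.bias": torch.zeros(2)}

fresh = nn.Module()  # no parameters registered yet
try:
    fresh.load_state_dict(saved)  # RuntimeError: unexpected keys
except RuntimeError:
    fresh.fc = nn.Linear(2, 2)  # the "one training step" creates the params
    fresh.load_state_dict(saved)  # now succeeds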
Example #9
    def load(
        cls,
        dir_path: str,
        prefix: Optional[str] = None,
        adata: Optional[AnnData] = None,
        use_gpu: Optional[Union[str, int, bool]] = None,
    ):
        """
        Instantiate a model from the saved output.

        Parameters
        ----------
        dir_path
            Path to saved outputs.
        prefix
            Prefix of saved file names.
        adata
            AnnData organized in the same way as data used to train model.
            It is not necessary to run setup_anndata,
            as AnnData is validated against the saved `scvi` setup dictionary.
            If None, will check for and load anndata saved with the model.
        use_gpu
            Load model on default GPU if available (if None or True),
            or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False).

        Returns
        -------
        Model with loaded state dictionaries.

        Examples
        --------
        >>> model = ModelClass.load(save_path, adata)  # use the model class that was used to save
        >>> model.get_....
        """
        load_adata = adata is None
        use_gpu, device = parse_use_gpu_arg(use_gpu)

        (
            attr_dict,
            var_names,
            model_state_dict,
            new_adata,
        ) = _load_saved_files(dir_path, load_adata, map_location=device, prefix=prefix)
        adata = new_adata if new_adata is not None else adata

        scvi_setup_dict = attr_dict.pop("scvi_setup_dict_")

        # Filter out keys that are no longer populated by setup_anndata.
        # TODO(jhong): remove hack with setup_anndata refactor.
        deprecated_keys = {"local_l_mean", "local_l_var"}
        scvi_setup_dict["data_registry"] = {
            k: v
            for k, v in scvi_setup_dict["data_registry"].items()
            if k not in deprecated_keys
        }

        _validate_var_names(adata, var_names)
        transfer_anndata_setup(scvi_setup_dict, adata)
        model = _initialize_model(cls, adata, attr_dict)

        # set saved attrs for loaded model
        for attr, val in attr_dict.items():
            setattr(model, attr, val)

        # some Pyro modules with AutoGuides may need one training step
        try:
            model.module.load_state_dict(model_state_dict)
        except RuntimeError as err:
            if isinstance(model.module, PyroBaseModuleClass):
                old_history = model.history_.copy()
                logger.info("Preparing underlying module for load")
                model.train(max_steps=1)
                model.history_ = old_history
                pyro.clear_param_store()
                model.module.load_state_dict(model_state_dict)
            else:
                raise err

        model.to_device(device)
        model.module.eval()
        model._validate_anndata(adata)
        return model
Example #10
    def load_query_data(
        cls,
        adata: AnnData,
        reference_model: Union[str, BaseModelClass],
        use_cuda: bool = True,
        freeze_dropout: bool = False,
        freeze_expression: bool = True,
        freeze_batchnorm_encoder: bool = True,
        freeze_batchnorm_decoder: bool = False,
    ):
        """
        Online update of a reference model with scArches algorithm [Lotfollahi20]_.

        Parameters
        ----------
        adata
            AnnData organized in the same way as data used to train model.
            It is not necessary to run :func:`~scvi.data.setup_anndata`,
            as AnnData is validated against the saved `scvi` setup dictionary.
        reference_model
            Either an already instantiated model of the same class, or a path to
            saved outputs for reference model.
        use_cuda
            Whether to load model on GPU.
        freeze_dropout
            Whether to freeze dropout during training
        freeze_expression
            Freeze neurons corresponding to expression in first layer
        freeze_batchnorm_encoder
            Whether to freeze batchnorm weight and bias during training for encoder
        freeze_batchnorm_decoder
            Whether to freeze batchnorm weight and bias during training for decoder
        """
        use_cuda = use_cuda and torch.cuda.is_available()

        if isinstance(reference_model, str):
            map_location = torch.device("cpu") if use_cuda is False else None
            (
                scvi_setup_dict,
                attr_dict,
                var_names,
                load_state_dict,
                _,
            ) = _load_saved_files(reference_model,
                                  load_adata=False,
                                  map_location=map_location)
        else:
            attr_dict = reference_model._get_user_attributes()
            attr_dict = {a[0]: a[1] for a in attr_dict if a[0][-1] == "_"}
            scvi_setup_dict = attr_dict.pop("scvi_setup_dict_")
            var_names = reference_model.adata.var_names
            load_state_dict = reference_model.model.state_dict().copy()

        _validate_var_names(adata, var_names)
        transfer_anndata_setup(scvi_setup_dict, adata, extend_categories=True)
        # for scanvi, any new labels in query cannot be used to extend the model
        adata.uns["_scvi"]["summary_stats"]["n_labels"] = scvi_setup_dict[
            "summary_stats"]["n_labels"]

        model = _initialize_model(cls, adata, attr_dict, use_cuda)

        # set saved attrs for loaded model
        for attr, val in attr_dict.items():
            setattr(model, attr, val)

        if use_cuda:
            model.model.cuda()

        # model tweaking
        new_state_dict = model.model.state_dict()
        for key, load_ten in load_state_dict.items():
            new_ten = new_state_dict[key]
            if new_ten.size() == load_ten.size():
                continue
            # new categoricals changed size
            else:
                dim_diff = new_ten.size()[-1] - load_ten.size()[-1]
                fixed_ten = torch.cat([load_ten, new_ten[..., -dim_diff:]],
                                      dim=-1)
                load_state_dict[key] = fixed_ten

        model.model.load_state_dict(load_state_dict)
        model.model.eval()

        _set_params_online_update(
            model.model,
            freeze_batchnorm_encoder=freeze_batchnorm_encoder,
            freeze_batchnorm_decoder=freeze_batchnorm_decoder,
            freeze_dropout=freeze_dropout,
            freeze_expression=freeze_expression,
        )
        model.is_trained_ = False

        return model
Example #11
def test_scvi():
    n_latent = 5
    adata = synthetic_iid()
    model = SCVI(adata, n_latent=n_latent)
    model.train(1)
    assert model.is_trained is True
    z = model.get_latent_representation()
    assert z.shape == (adata.shape[0], n_latent)
    model.get_elbo()
    model.get_marginal_ll()
    model.get_reconstruction_error()
    model.get_normalized_expression()

    adata2 = synthetic_iid()
    model.get_elbo(adata2)
    model.get_marginal_ll(adata2)
    model.get_reconstruction_error(adata2)
    latent = model.get_latent_representation(adata2, indices=[1, 2, 3])
    assert latent.shape == (3, n_latent)
    denoised = model.get_normalized_expression(adata2)
    assert denoised.shape == adata.shape

    denoised = model.get_normalized_expression(adata2,
                                               indices=[1, 2, 3],
                                               transform_batch=1)
    assert denoised.shape == (3, adata2.n_vars)
    sample = model.posterior_predictive_sample(adata2)
    assert sample.shape == adata2.shape
    sample = model.posterior_predictive_sample(adata2,
                                               indices=[1, 2, 3],
                                               gene_list=["1", "2"])
    assert sample.shape == (3, 2)
    sample = model.posterior_predictive_sample(adata2,
                                               indices=[1, 2, 3],
                                               gene_list=["1", "2"],
                                               n_samples=3)
    assert sample.shape == (3, 2, 3)

    model.get_feature_correlation_matrix(correlation_type="pearson")
    model.get_feature_correlation_matrix(
        adata2,
        indices=[1, 2, 3],
        correlation_type="spearman",
        rna_size_factor=500,
        n_samples=5,
    )
    params = model.get_likelihood_parameters()
    assert params["mean"].shape == adata.shape
    assert (params["mean"].shape == params["dispersions"].shape ==
            params["dropout"].shape)
    params = model.get_likelihood_parameters(adata2, indices=[1, 2, 3])
    assert params["mean"].shape == (3, adata.n_vars)
    params = model.get_likelihood_parameters(adata2,
                                             indices=[1, 2, 3],
                                             n_samples=3,
                                             give_mean=True)
    assert params["mean"].shape == (3, adata.n_vars)
    model.get_latent_library_size()
    model.get_latent_library_size(adata2, indices=[1, 2, 3])

    # test transfer_anndata_setup
    adata2 = synthetic_iid(run_setup_anndata=False)
    transfer_anndata_setup(adata, adata2)
    model.get_elbo(adata2)

    # test automatic transfer_anndata_setup + on a view
    adata = synthetic_iid()
    model = SCVI(adata)
    adata2 = synthetic_iid(run_setup_anndata=False)
    model.get_elbo(adata2[:10])

    # test that we catch incorrect mappings
    adata = synthetic_iid()
    adata2 = synthetic_iid(run_setup_anndata=False)
    transfer_anndata_setup(adata, adata2)
    adata2.uns["_scvi"]["categorical_mappings"]["_scvi_labels"][
        "mapping"] = pd.Index(
            data=["undefined_1", "undefined_0", "undefined_2"])
    with pytest.raises(ValueError):
        model.get_elbo(adata2)

    # test mismatched categories raises ValueError
    adata2 = synthetic_iid(run_setup_anndata=False)
    adata2.obs.labels.cat.rename_categories(["a", "b", "c"], inplace=True)
    with pytest.raises(ValueError):
        model.get_elbo(adata2)

    # test differential expression
    model.differential_expression(groupby="labels", group1="undefined_1")
    model.differential_expression(groupby="labels",
                                  group1="undefined_1",
                                  group2="undefined_2",
                                  mode="change")
    model.differential_expression(groupby="labels")
    model.differential_expression(idx1=[0, 1, 2], idx2=[3, 4, 5])
    model.differential_expression(idx1=[0, 1, 2])
Example #12
    def from_scvi_model(
        cls,
        scvi_model: SCVI,
        adata: Optional[AnnData] = None,
        restrict_to_batch: Optional[str] = None,
        doublet_ratio: int = 2,
        **classifier_kwargs,
    ):
        """
        Instantiate a SOLO model from an scvi model.

        Parameters
        ----------
        scvi_model
            Pre-trained model of :class:`~scvi.model.SCVI`. The
            adata object used to initialize this model should have only
            been setup with count data, and optionally a `batch_key`;
            i.e., no extra covariates or labels, etc.
        adata
            Optional anndata to use that is compatible with scvi_model.
        restrict_to_batch
            Batch category in `batch_key` used to setup adata for scvi_model
            to restrict the Solo model to. This allows training a Solo model on
            one batch of an scvi_model that was trained on multiple batches.
        doublet_ratio
            Ratio of generated doublets to produce relative to number of
            cells in adata or length of indices, if not `None`.
        **classifier_kwargs
            Keyword args for :class:`~scvi.module.Classifier`

        Returns
        -------
        SOLO model
        """
        _validate_scvi_model(scvi_model, restrict_to_batch=restrict_to_batch)
        orig_adata = scvi_model.adata
        orig_batch_key = scvi_model.scvi_setup_dict_["categorical_mappings"][
            "_scvi_batch"]["original_key"]

        if adata is not None:
            transfer_anndata_setup(orig_adata, adata)
        else:
            adata = orig_adata

        if restrict_to_batch is not None:
            batch_mask = adata.obs[orig_batch_key] == restrict_to_batch
            if np.sum(batch_mask) == 0:
                raise ValueError(
                    "Batch category given to restrict_to_batch not found.\n" +
                    "Available categories: {}".format(
                        adata.obs[orig_batch_key].astype(
                            "category").cat.categories))
            # indices in adata with restrict_to_batch category
            batch_indices = np.where(batch_mask)[0]
        else:
            # use all indices
            batch_indices = None

        # anndata with only generated doublets
        doublet_adata = cls.create_doublets(adata,
                                            indices=batch_indices,
                                            doublet_ratio=doublet_ratio)
        # if scvi wasn't trained with batch correction, having the
        # zeros here does nothing.
        doublet_adata.obs[orig_batch_key] = (
            restrict_to_batch if restrict_to_batch is not None else 0)

        # if the model uses observed lib size, the lib sample
        # is just the observed lib size on log scale
        give_mean_lib = not scvi_model.module.use_observed_lib_size

        # get latent representations and make input anndata
        latent_rep = scvi_model.get_latent_representation(
            adata, indices=batch_indices)
        lib_size = scvi_model.get_latent_library_size(adata,
                                                      indices=batch_indices,
                                                      give_mean=give_mean_lib)
        latent_adata = AnnData(
            np.concatenate([latent_rep, np.log(lib_size)], axis=1))
        latent_adata.obs[LABELS_KEY] = "singlet"
        orig_obs_names = adata.obs_names
        latent_adata.obs_names = (orig_obs_names[batch_indices]
                                  if batch_indices is not None else
                                  orig_obs_names)

        logger.info("Creating doublets, preparing SOLO model.")
        f = io.StringIO()
        with redirect_stdout(f):
            setup_anndata(doublet_adata, batch_key=orig_batch_key)
            doublet_latent_rep = scvi_model.get_latent_representation(
                doublet_adata)
            doublet_lib_size = scvi_model.get_latent_library_size(
                doublet_adata, give_mean=give_mean_lib)
            doublet_adata = AnnData(
                np.concatenate([doublet_latent_rep,
                                np.log(doublet_lib_size)],
                               axis=1))
            doublet_adata.obs[LABELS_KEY] = "doublet"

            full_adata = latent_adata.concatenate(doublet_adata)
            setup_anndata(full_adata, labels_key=LABELS_KEY)
        return cls(full_adata, **classifier_kwargs)
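
A hedged end-to-end sketch of how `from_scvi_model` is meant to be used, mirroring scvi-tools' own SOLO test; `predict()` returns per-cell doublet probabilities, and the single-epoch training is only for illustration.

from scvi.data import setup_anndata, synthetic_iid
from scvi.external import SOLO
from scvi.model import SCVI

adata = synthetic_iid(run_setup_anndata=False)
setup_anndata(adata)  # counts only: no labels or extra covariates
scvi_model = SCVI(adata)
scvi_model.train(1)

solo = SOLO.from_scvi_model(scvi_model)
solo.train(1)
doublet_probs = solo.predict()  # per-cell singlet/doublet probabilities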
Example #13
    def load(
        cls,
        dir_path: str,
        adata_seq: Optional[AnnData] = None,
        adata_spatial: Optional[AnnData] = None,
        use_gpu: Optional[Union[str, int, bool]] = None,
    ):
        """
        Instantiate a model from the saved output.

        Parameters
        ----------
        dir_path
            Path to saved outputs.
        adata_seq
            AnnData organized in the same way as data used to train model.
            It is not necessary to run :func:`~scvi.data.setup_anndata`,
            as AnnData is validated against the saved `scvi` setup dictionary.
            If None, will check for and load anndata saved with the model.
        adata_spatial
            AnnData organized in the same way as data used to train model.
            If None, will check for and load anndata saved with the model.
        use_gpu
            Load model on default GPU if available (if None or True),
            or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False).

        Returns
        -------
        Model with loaded state dictionaries.

        Examples
        --------
        >>> vae = GIMVI.load(save_path, adata_seq, adata_spatial)
        >>> vae.get_latent_representation()
        """
        model_path = os.path.join(dir_path, "model_params.pt")
        setup_dict_path = os.path.join(dir_path, "attr.pkl")
        seq_data_path = os.path.join(dir_path, "adata_seq.h5ad")
        spatial_data_path = os.path.join(dir_path, "adata_spatial.h5ad")
        seq_var_names_path = os.path.join(dir_path, "var_names_seq.csv")
        spatial_var_names_path = os.path.join(dir_path, "var_names_spatial.csv")

        if adata_seq is None and os.path.exists(seq_data_path):
            adata_seq = read(seq_data_path)
        elif adata_seq is None and not os.path.exists(seq_data_path):
            raise ValueError(
                "Save path contains no saved anndata and no adata was passed."
            )
        if adata_spatial is None and os.path.exists(spatial_data_path):
            adata_spatial = read(spatial_data_path)
        elif adata_spatial is None and not os.path.exists(spatial_data_path):
            raise ValueError(
                "Save path contains no saved anndata and no adata was passed."
            )
        adatas = [adata_seq, adata_spatial]

        seq_var_names = np.genfromtxt(seq_var_names_path, delimiter=",", dtype=str)
        spatial_var_names = np.genfromtxt(
            spatial_var_names_path, delimiter=",", dtype=str
        )
        var_names = [seq_var_names, spatial_var_names]

        for i, adata in enumerate(adatas):
            saved_var_names = var_names[i]
            user_var_names = adata.var_names.astype(str)
            if not np.array_equal(saved_var_names, user_var_names):
                warnings.warn(
                    "var_names for adata passed in does not match var_names of "
                    "adata used to train the model. For valid results, the vars "
                    "need to be the same and in the same order as the adata used to train the model."
                )

        with open(setup_dict_path, "rb") as handle:
            attr_dict = pickle.load(handle)

        scvi_setup_dicts = attr_dict.pop("scvi_setup_dicts_")
        transfer_anndata_setup(scvi_setup_dicts["seq"], adata_seq)
        transfer_anndata_setup(scvi_setup_dicts["spatial"], adata_spatial)

        # get the parameters for the class init signature
        init_params = attr_dict.pop("init_params_")

        # new saving and loading, enable backwards compatibility
        if "non_kwargs" in init_params.keys():
            # grab all the parameters except for kwargs (which is a dict)
            non_kwargs = init_params["non_kwargs"]
            kwargs = init_params["kwargs"]

            # expand out kwargs
            kwargs = {k: v for (i, j) in kwargs.items() for (k, v) in j.items()}
        else:
            # grab all the parameters except for kwargs (which is a dict)
            non_kwargs = {
                k: v for k, v in init_params.items() if not isinstance(v, dict)
            }
            kwargs = {k: v for k, v in init_params.items() if isinstance(v, dict)}
            kwargs = {k: v for (i, j) in kwargs.items() for (k, v) in j.items()}
        model = cls(adata_seq, adata_spatial, **non_kwargs, **kwargs)

        for attr, val in attr_dict.items():
            setattr(model, attr, val)

        _, device = parse_use_gpu_arg(use_gpu)
        model.module.load_state_dict(torch.load(model_path, map_location=device))
        model.module.eval()
        model.to_device(device)
        return model
Example #14
def test_scvi(save_path):
    n_latent = 5
    adata = synthetic_iid()
    model = SCVI(adata, n_latent=n_latent)
    model.train(1, check_val_every_n_epoch=1, train_size=0.5)

    model = SCVI(adata, n_latent=n_latent, var_activation=Softplus())
    model.train(1, check_val_every_n_epoch=1, train_size=0.5)

    # tests __repr__
    print(model)

    assert model.is_trained is True
    z = model.get_latent_representation()
    assert z.shape == (adata.shape[0], n_latent)
    assert len(model.history["elbo_train"]) == 1
    model.get_elbo()
    model.get_marginal_ll(n_mc_samples=3)
    model.get_reconstruction_error()
    model.get_normalized_expression(transform_batch="batch_1")

    adata2 = synthetic_iid()
    model.get_elbo(adata2)
    model.get_marginal_ll(adata2, n_mc_samples=3)
    model.get_reconstruction_error(adata2)
    latent = model.get_latent_representation(adata2, indices=[1, 2, 3])
    assert latent.shape == (3, n_latent)
    denoised = model.get_normalized_expression(adata2)
    assert denoised.shape == adata.shape

    denoised = model.get_normalized_expression(
        adata2, indices=[1, 2, 3], transform_batch="batch_1"
    )
    denoised = model.get_normalized_expression(
        adata2, indices=[1, 2, 3], transform_batch=["batch_0", "batch_1"]
    )
    assert denoised.shape == (3, adata2.n_vars)
    sample = model.posterior_predictive_sample(adata2)
    assert sample.shape == adata2.shape
    sample = model.posterior_predictive_sample(
        adata2, indices=[1, 2, 3], gene_list=["1", "2"]
    )
    assert sample.shape == (3, 2)
    sample = model.posterior_predictive_sample(
        adata2, indices=[1, 2, 3], gene_list=["1", "2"], n_samples=3
    )
    assert sample.shape == (3, 2, 3)

    model.get_feature_correlation_matrix(correlation_type="pearson")
    model.get_feature_correlation_matrix(
        adata2,
        indices=[1, 2, 3],
        correlation_type="spearman",
        rna_size_factor=500,
        n_samples=5,
    )
    model.get_feature_correlation_matrix(
        adata2,
        indices=[1, 2, 3],
        correlation_type="spearman",
        rna_size_factor=500,
        n_samples=5,
        transform_batch=["batch_0", "batch_1"],
    )
    params = model.get_likelihood_parameters()
    assert params["mean"].shape == adata.shape
    assert (
        params["mean"].shape == params["dispersions"].shape == params["dropout"].shape
    )
    params = model.get_likelihood_parameters(adata2, indices=[1, 2, 3])
    assert params["mean"].shape == (3, adata.n_vars)
    params = model.get_likelihood_parameters(
        adata2, indices=[1, 2, 3], n_samples=3, give_mean=True
    )
    assert params["mean"].shape == (3, adata.n_vars)
    model.get_latent_library_size()
    model.get_latent_library_size(adata2, indices=[1, 2, 3])

    # test transfer_anndata_setup
    adata2 = synthetic_iid(run_setup_anndata=False)
    transfer_anndata_setup(adata, adata2)
    model.get_elbo(adata2)

    # test automatic transfer_anndata_setup + on a view
    adata = synthetic_iid()
    model = SCVI(adata)
    adata2 = synthetic_iid(run_setup_anndata=False)
    model.get_elbo(adata2[:10])

    # test that we catch incorrect mappings
    adata = synthetic_iid()
    adata2 = synthetic_iid(run_setup_anndata=False)
    transfer_anndata_setup(adata, adata2)
    adata2.uns["_scvi"]["categorical_mappings"]["_scvi_labels"]["mapping"] = np.array(
        ["label_4", "label_0", "label_2"]
    )
    with pytest.raises(ValueError):
        model.get_elbo(adata2)

    # test that same mapping different order doesn't raise error
    adata = synthetic_iid()
    adata2 = synthetic_iid(run_setup_anndata=False)
    transfer_anndata_setup(adata, adata2)
    adata2.uns["_scvi"]["categorical_mappings"]["_scvi_labels"]["mapping"] = np.array(
        ["label_1", "label_0", "label_2"]
    )
    model.get_elbo(adata2)  # should automatically transfer setup

    # test mismatched categories raises ValueError
    adata2 = synthetic_iid(run_setup_anndata=False)
    adata2.obs.labels.cat.rename_categories(["a", "b", "c"], inplace=True)
    with pytest.raises(ValueError):
        model.get_elbo(adata2)

    # test differential expression
    model.differential_expression(groupby="labels", group1="label_1")
    model.differential_expression(
        groupby="labels", group1="label_1", group2="label_2", mode="change"
    )
    model.differential_expression(groupby="labels")
    model.differential_expression(idx1=[0, 1, 2], idx2=[3, 4, 5])
    model.differential_expression(idx1=[0, 1, 2])

    # transform batch works with all different types
    a = synthetic_iid(run_setup_anndata=False)
    batch = np.zeros(a.n_obs)
    batch[:64] += 1
    a.obs["batch"] = batch
    setup_anndata(a, batch_key="batch")
    m = SCVI(a)
    m.train(1, train_size=0.5)
    m.get_normalized_expression(transform_batch=1)
    m.get_normalized_expression(transform_batch=[0, 1])

    # test get_likelihood_parameters() when dispersion=='gene-cell'
    model = SCVI(adata, dispersion="gene-cell")
    model.get_likelihood_parameters()

    # test train callbacks work
    a = synthetic_iid()
    m = scvi.model.SCVI(a)
    lr_monitor = LearningRateMonitor()
    m.train(
        callbacks=[lr_monitor],
        max_epochs=10,
        log_every_n_steps=1,
        plan_kwargs={"reduce_lr_on_plateau": True},
    )
    assert "lr-Adam" in m.history.keys()
Example #15
    def load_query_data(
        cls,
        adata: AnnData,
        reference_model: Union[str, BaseModelClass],
        inplace_subset_query_vars: bool = False,
        use_gpu: Optional[Union[str, int, bool]] = None,
        unfrozen: bool = False,
        freeze_dropout: bool = False,
        freeze_expression: bool = True,
        freeze_decoder_first_layer: bool = True,
        freeze_batchnorm_encoder: bool = True,
        freeze_batchnorm_decoder: bool = False,
        freeze_classifier: bool = True,
    ):
        """
        Online update of a reference model with scArches algorithm [Lotfollahi21]_.

        Parameters
        ----------
        adata
            AnnData organized in the same way as data used to train model.
            It is not necessary to run setup_anndata,
            as AnnData is validated against the saved `scvi` setup dictionary.
        reference_model
            Either an already instantiated model of the same class, or a path to
            saved outputs for reference model.
        inplace_subset_query_vars
            Whether to subset and rearrange query vars inplace based on vars used to
            train reference model.
        use_gpu
            Load model on default GPU if available (if None or True),
            or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False).
        unfrozen
            Override all other freeze options for a fully unfrozen model
        freeze_dropout
            Whether to freeze dropout during training
        freeze_expression
            Freeze neurons corresponding to expression in first layer
        freeze_decoder_first_layer
            Freeze neurons corresponding to first layer in decoder
        freeze_batchnorm_encoder
            Whether to freeze batchnorm weight and bias during training for encoder
        freeze_batchnorm_decoder
            Whether to freeze batchnorm weight and bias during training for decoder
        freeze_classifier
            Whether to freeze classifier completely. Only applies to `SCANVI`.
        """
        use_gpu, device = parse_use_gpu_arg(use_gpu)
        if isinstance(reference_model, str):
            (
                attr_dict,
                var_names,
                load_state_dict,
                _,
            ) = _load_saved_files(reference_model,
                                  load_adata=False,
                                  map_location=device)
        else:
            attr_dict = reference_model._get_user_attributes()
            attr_dict = {a[0]: a[1] for a in attr_dict if a[0][-1] == "_"}
            var_names = reference_model.adata.var_names
            load_state_dict = deepcopy(reference_model.module.state_dict())

        scvi_setup_dict = attr_dict.pop("scvi_setup_dict_")

        if inplace_subset_query_vars:
            logger.debug("Subsetting query vars to reference vars.")
            adata._inplace_subset_var(var_names)
        _validate_var_names(adata, var_names)

        version_split = scvi_setup_dict["scvi_version"].split(".")
        if int(version_split[0]) == 0 and int(version_split[1]) < 8:
            warnings.warn(
                "Query integration should be performed using models trained with version >= 0.8"
            )

        transfer_anndata_setup(scvi_setup_dict, adata, extend_categories=True)

        model = _initialize_model(cls, adata, attr_dict)

        # set saved attrs for loaded model
        for attr, val in attr_dict.items():
            setattr(model, attr, val)

        model.to_device(device)

        # model tweaking
        new_state_dict = model.module.state_dict()
        for key, load_ten in load_state_dict.items():
            new_ten = new_state_dict[key]
            if new_ten.size() == load_ten.size():
                continue
            # new categoricals changed size
            else:
                dim_diff = new_ten.size()[-1] - load_ten.size()[-1]
                fixed_ten = torch.cat([load_ten, new_ten[..., -dim_diff:]],
                                      dim=-1)
                load_state_dict[key] = fixed_ten

        model.module.load_state_dict(load_state_dict)
        model.module.eval()

        _set_params_online_update(
            model.module,
            unfrozen=unfrozen,
            freeze_decoder_first_layer=freeze_decoder_first_layer,
            freeze_batchnorm_encoder=freeze_batchnorm_encoder,
            freeze_batchnorm_decoder=freeze_batchnorm_decoder,
            freeze_dropout=freeze_dropout,
            freeze_expression=freeze_expression,
            freeze_classifier=freeze_classifier,
        )
        model.is_trained_ = False

        return model
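
A hedged sketch of the scArches query workflow around `load_query_data`, loosely following the scvi-tools tutorials; `"reference_model/"` is an illustrative path, and the `weight_decay` plan kwarg is the tutorials' recommendation for query training, not something this code requires.

from scvi.data import synthetic_iid
from scvi.model import SCVI

query = synthetic_iid(run_setup_anndata=False)
model = SCVI.load_query_data(query, "reference_model/")
model.train(max_epochs=1, plan_kwargs={"weight_decay": 0.0})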
Example #16
def test_totalvi(save_path):
    adata = synthetic_iid()
    n_obs = adata.n_obs
    n_vars = adata.n_vars
    n_proteins = adata.obsm["protein_expression"].shape[1]
    n_latent = 10

    model = TOTALVI(adata, n_latent=n_latent)
    model.train(1, train_size=0.5)
    assert model.is_trained is True
    z = model.get_latent_representation()
    assert z.shape == (n_obs, n_latent)
    model.get_elbo()
    model.get_marginal_ll(n_mc_samples=3)
    model.get_reconstruction_error()
    model.get_normalized_expression()
    model.get_normalized_expression(transform_batch=["batch_0", "batch_1"])
    model.get_latent_library_size()
    model.get_protein_foreground_probability()
    model.get_protein_foreground_probability(transform_batch=["batch_0", "batch_1"])
    post_pred = model.posterior_predictive_sample(n_samples=2)
    assert post_pred.shape == (n_obs, n_vars + n_proteins, 2)
    post_pred = model.posterior_predictive_sample(n_samples=1)
    assert post_pred.shape == (n_obs, n_vars + n_proteins)
    feature_correlation_matrix1 = model.get_feature_correlation_matrix(
        correlation_type="spearman"
    )
    feature_correlation_matrix1 = model.get_feature_correlation_matrix(
        correlation_type="spearman", transform_batch=["batch_0", "batch_1"]
    )
    feature_correlation_matrix2 = model.get_feature_correlation_matrix(
        correlation_type="pearson"
    )
    assert feature_correlation_matrix1.shape == (
        n_vars + n_proteins,
        n_vars + n_proteins,
    )
    assert feature_correlation_matrix2.shape == (
        n_vars + n_proteins,
        n_vars + n_proteins,
    )
    # model.get_likelihood_parameters()

    model.get_elbo(indices=model.validation_indices)
    model.get_marginal_ll(indices=model.validation_indices, n_mc_samples=3)
    model.get_reconstruction_error(indices=model.validation_indices)

    adata2 = synthetic_iid()
    norm_exp = model.get_normalized_expression(adata2, indices=[1, 2, 3])
    assert norm_exp[0].shape == (3, adata2.n_vars)
    assert norm_exp[1].shape == (3, adata2.obsm["protein_expression"].shape[1])

    latent_lib_size = model.get_latent_library_size(adata2, indices=[1, 2, 3])
    assert latent_lib_size.shape == (3, 1)

    pro_foreground_prob = model.get_protein_foreground_probability(
        adata2, indices=[1, 2, 3], protein_list=["1", "2"]
    )
    assert pro_foreground_prob.shape == (3, 2)
    model.posterior_predictive_sample(adata2)
    model.get_feature_correlation_matrix(adata2)
    # model.get_likelihood_parameters(adata2)

    # test transfer_anndata_setup + view
    adata2 = synthetic_iid(run_setup_anndata=False)
    transfer_anndata_setup(adata, adata2)
    model.get_elbo(adata2[:10])

    # test automatic transfer_anndata_setup
    adata = synthetic_iid()
    model = TOTALVI(adata)
    adata2 = synthetic_iid(run_setup_anndata=False)
    model.get_elbo(adata2)

    # test that we catch incorrect mappings
    adata = synthetic_iid()
    adata2 = synthetic_iid(run_setup_anndata=False)
    transfer_anndata_setup(adata, adata2)
    adata2.uns["_scvi"]["categorical_mappings"]["_scvi_labels"]["mapping"] = np.array(
        ["label_1", "label_0", "label_8"]
    )
    with pytest.raises(ValueError):
        model.get_elbo(adata2)

    # test that same mapping different order is okay
    adata = synthetic_iid()
    adata2 = synthetic_iid(run_setup_anndata=False)
    transfer_anndata_setup(adata, adata2)
    adata2.uns["_scvi"]["categorical_mappings"]["_scvi_labels"]["mapping"] = np.array(
        ["label_1", "label_0", "label_2"]
    )
    model.get_elbo(adata2)  # should automatically transfer setup

    # test that we catch missing proteins
    adata2 = synthetic_iid(run_setup_anndata=False)
    del adata2.obsm["protein_expression"]
    with pytest.raises(KeyError):
        model.get_elbo(adata2)
    model.differential_expression(groupby="labels", group1="label_1")
    model.differential_expression(groupby="labels", group1="label_1", group2="label_2")
    model.differential_expression(idx1=[0, 1, 2], idx2=[3, 4, 5])
    model.differential_expression(idx1=[0, 1, 2])
    model.differential_expression(groupby="labels")

    # test with missing proteins
    adata = scvi.data.pbmcs_10x_cite_seq(save_path=save_path, protein_join="outer")
    model = TOTALVI(adata)
    assert model.module.protein_batch_mask is not None
    model.train(1, train_size=0.5)
Example #17
    def load(
        cls,
        dir_path: str,
        prefix: Optional[str] = None,
        adata_seq: Optional[AnnData] = None,
        adata_spatial: Optional[AnnData] = None,
        use_gpu: Optional[Union[str, int, bool]] = None,
    ):
        """
        Instantiate a model from the saved output.

        Parameters
        ----------
        dir_path
            Path to saved outputs.
        prefix
            Prefix of saved file names.
        adata_seq
            AnnData organized in the same way as data used to train model.
            It is not necessary to run :meth:`~scvi.external.GIMVI.setup_anndata`,
            as AnnData is validated against the saved `scvi` setup dictionary.
            If None, will check for and load anndata saved with the model.
        adata_spatial
            AnnData organized in the same way as data used to train model.
            If None, will check for and load anndata saved with the model.
        use_gpu
            Load model on default GPU if available (if None or True),
            or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False).

        Returns
        -------
        Model with loaded state dictionaries.

        Examples
        --------
        >>> vae = GIMVI.load(save_path, adata_seq=adata_seq, adata_spatial=adata_spatial)
        >>> vae.get_latent_representation()
        """
        _, device = parse_use_gpu_arg(use_gpu)

        (
            attr_dict,
            seq_var_names,
            spatial_var_names,
            model_state_dict,
            loaded_adata_seq,
            loaded_adata_spatial,
        ) = _load_saved_gimvi_files(
            dir_path,
            adata_seq is None,
            adata_spatial is None,
            prefix=prefix,
            map_location=device,
        )
        adata_seq = loaded_adata_seq or adata_seq
        adata_spatial = loaded_adata_spatial or adata_spatial
        adatas = [adata_seq, adata_spatial]
        var_names = [seq_var_names, spatial_var_names]

        for i, adata in enumerate(adatas):
            saved_var_names = var_names[i]
            user_var_names = adata.var_names.astype(str)
            if not np.array_equal(saved_var_names, user_var_names):
                warnings.warn(
                    "var_names for adata passed in does not match var_names of "
                    "adata used to train the model. For valid results, the vars "
                    "need to be the same and in the same order as the adata used to train the model."
                )

        scvi_setup_dicts = attr_dict.pop("scvi_setup_dicts_")
        transfer_anndata_setup(scvi_setup_dicts["seq"], adata_seq)
        transfer_anndata_setup(scvi_setup_dicts["spatial"], adata_spatial)

        # get the parameters for the class init signature
        init_params = attr_dict.pop("init_params_")

        # new saving and loading, enable backwards compatibility
        if "non_kwargs" in init_params.keys():
            # grab all the parameters except for kwargs (which is a dict)
            non_kwargs = init_params["non_kwargs"]
            kwargs = init_params["kwargs"]

            # expand out kwargs
            kwargs = {
                k: v
                for (i, j) in kwargs.items() for (k, v) in j.items()
            }
        else:
            # grab all the parameters except for kwargs (which is a dict)
            non_kwargs = {
                k: v
                for k, v in init_params.items() if not isinstance(v, dict)
            }
            kwargs = {
                k: v
                for k, v in init_params.items() if isinstance(v, dict)
            }
            kwargs = {
                k: v
                for (i, j) in kwargs.items() for (k, v) in j.items()
            }
        model = cls(adata_seq, adata_spatial, **non_kwargs, **kwargs)

        for attr, val in attr_dict.items():
            setattr(model, attr, val)

        model.module.load_state_dict(model_state_dict)
        model.module.eval()
        model.to_device(device)
        return model