예제 #1
0
    def __init__(
        self,
        adata: AnnData,
        n_latent: int = 20,
        gene_dispersion: Literal[
            "gene", "gene-batch", "gene-label", "gene-cell"
        ] = "gene",
        protein_dispersion: Literal[
            "protein", "protein-batch", "protein-label"
        ] = "protein",
        gene_likelihood: Literal["zinb", "nb"] = "nb",
        latent_distribution: Literal["normal", "ln"] = "normal",
        empirical_protein_background_prior: Optional[bool] = None,
        **model_kwargs,
    ):
        """Create the model and its ``TOTALVAE`` module for ``adata``.

        Parameters
        ----------
        adata
            Passed to the parent initializer and, when the empirical prior is
            enabled, to ``_get_totalvi_protein_priors``.
        n_latent
            Forwarded to ``TOTALVAE`` as ``n_latent``.
        gene_dispersion
            Forwarded to ``TOTALVAE`` as ``gene_dispersion``.
        protein_dispersion
            Forwarded to ``TOTALVAE`` as ``protein_dispersion``.
        gene_likelihood
            Forwarded to ``TOTALVAE`` as ``gene_likelihood``.
        latent_distribution
            Forwarded to ``TOTALVAE`` as ``latent_distribution``.
        empirical_protein_background_prior
            Whether to compute the protein background prior from ``adata``.
            When ``None``, it is enabled only if
            ``self.summary_stats["n_proteins"]`` is greater than 10.
        **model_kwargs
            Extra keyword arguments forwarded to ``TOTALVAE``.
        """
        super(TOTALVI, self).__init__(adata)

        # NOTE: the last statement hands ``locals()`` to ``_get_init_params``,
        # so the local variable names in this method are deliberately stable.

        # Use the batch mask only when the setup dict recorded one.
        batch_mask = (
            self.scvi_setup_dict_["totalvi_batch_mask"]
            if "totalvi_batch_mask" in self.scvi_setup_dict_
            else None
        )

        # An explicit True/False from the caller wins; otherwise fall back to
        # the protein-count threshold.
        if empirical_protein_background_prior is not None:
            emp_prior = empirical_protein_background_prior
        else:
            emp_prior = self.summary_stats["n_proteins"] > 10

        prior_mean, prior_scale = (
            _get_totalvi_protein_priors(adata) if emp_prior else (None, None)
        )

        if "extra_categoricals" in self.scvi_setup_dict_:
            n_cats_per_cov = self.scvi_setup_dict_["extra_categoricals"][
                "n_cats_per_key"
            ]
        else:
            n_cats_per_cov = None

        self.module = TOTALVAE(
            n_input_genes=self.summary_stats["n_vars"],
            n_input_proteins=self.summary_stats["n_proteins"],
            n_batch=self.summary_stats["n_batch"],
            n_latent=n_latent,
            n_continuous_cov=self.summary_stats["n_continuous_covs"],
            n_cats_per_cov=n_cats_per_cov,
            gene_dispersion=gene_dispersion,
            protein_dispersion=protein_dispersion,
            gene_likelihood=gene_likelihood,
            latent_distribution=latent_distribution,
            protein_batch_mask=batch_mask,
            protein_background_prior_mean=prior_mean,
            protein_background_prior_scale=prior_scale,
            **model_kwargs,
        )

        # Human-readable summary of the main hyperparameters.
        self._model_summary_string = (
            "TotalVI Model with the following params: \nn_latent: {}, "
            "gene_dispersion: {}, protein_dispersion: {}, gene_likelihood: {}, latent_distribution: {}"
        ).format(
            n_latent,
            gene_dispersion,
            protein_dispersion,
            gene_likelihood,
            latent_distribution,
        )

        # Must stay last so every local above is bound when it is captured.
        self.init_params_ = self._get_init_params(locals())
예제 #2
0
    def __init__(
        self,
        adata: AnnData,
        n_latent: int = 20,
        gene_dispersion: Literal["gene", "gene-batch", "gene-label",
                                 "gene-cell"] = "gene",
        protein_dispersion: Literal["protein", "protein-batch",
                                    "protein-label"] = "protein",
        gene_likelihood: Literal["zinb", "nb"] = "nb",
        latent_distribution: Literal["normal", "ln"] = "normal",
        empirical_protein_background_prior: Optional[bool] = None,
        override_missing_proteins: bool = False,
        **model_kwargs,
    ):
        """Create the model and its ``TOTALVAE`` module for ``adata``.

        Parameters
        ----------
        adata
            Passed to the parent initializer and, when the empirical prior is
            enabled, to ``self._get_totalvi_protein_priors``.
        n_latent
            Forwarded to ``TOTALVAE`` as ``n_latent``.
        gene_dispersion
            Forwarded to ``TOTALVAE`` as ``gene_dispersion``.
        protein_dispersion
            Forwarded to ``TOTALVAE`` as ``protein_dispersion``.
        gene_likelihood
            Forwarded to ``TOTALVAE`` as ``gene_likelihood``.
        latent_distribution
            Forwarded to ``TOTALVAE`` as ``latent_distribution``.
        empirical_protein_background_prior
            Whether to compute the protein background prior from ``adata``.
            When ``None``, it is enabled only if
            ``self.summary_stats.n_proteins`` is greater than 10.
        override_missing_proteins
            When ``True``, a registered protein batch mask is ignored: no mask
            is passed to ``TOTALVAE``, no warning is emitted, and
            ``self._use_adversarial_classifier`` is set to ``False``.
        **model_kwargs
            Extra keyword arguments forwarded to ``TOTALVAE``.

        Warns
        -----
        UserWarning
            If a protein batch mask is registered and
            ``override_missing_proteins`` is ``False``.
        """
        super(TOTALVI, self).__init__(adata)

        # NOTE: the last statement hands ``locals()`` to ``_get_init_params``,
        # so the local variable names in this method are deliberately stable.
        self.protein_state_registry = self.adata_manager.get_state_registry(
            REGISTRY_KEYS.PROTEIN_EXP_KEY)
        if (ProteinObsmField.PROTEIN_BATCH_MASK in self.protein_state_registry
                and not override_missing_proteins):
            batch_mask = self.protein_state_registry.protein_batch_mask
            # Adjacent literals are concatenated verbatim, so each fragment
            # carries its own trailing space (the original ran the last two
            # together as "...=True`,to override").
            msg = (
                "Some proteins have all 0 counts in some batches. "
                "These proteins will be treated as missing measurements; however, "
                "this can occur due to experimental design/biology. "
                "Reinitialize the model with `override_missing_proteins=True`, "
                "to override this behavior.")
            # stacklevel=2 attributes the warning to the code constructing the
            # model rather than to this line.
            warnings.warn(msg, UserWarning, stacklevel=2)
            self._use_adversarial_classifier = True
        else:
            batch_mask = None
            self._use_adversarial_classifier = False

        # An explicit True/False from the caller wins; otherwise fall back to
        # the protein-count threshold.
        emp_prior = (empirical_protein_background_prior
                     if empirical_protein_background_prior is not None else
                     (self.summary_stats.n_proteins > 10))
        if emp_prior:
            prior_mean, prior_scale = self._get_totalvi_protein_priors(adata)
        else:
            prior_mean, prior_scale = None, None

        # Category counts are only looked up when categorical covariates were
        # registered.
        n_cats_per_cov = (self.adata_manager.get_state_registry(
            REGISTRY_KEYS.CAT_COVS_KEY)[
                CategoricalJointObsField.N_CATS_PER_KEY]
                          if REGISTRY_KEYS.CAT_COVS_KEY
                          in self.adata_manager.data_registry else None)

        n_batch = self.summary_stats.n_batch
        use_size_factor_key = (REGISTRY_KEYS.SIZE_FACTOR_KEY
                               in self.adata_manager.data_registry)
        # Library-size statistics are computed only when no size factor key is
        # registered; otherwise both stay None.
        library_log_means, library_log_vars = None, None
        if not use_size_factor_key:
            library_log_means, library_log_vars = _init_library_size(
                self.adata_manager, n_batch)

        self.module = TOTALVAE(
            n_input_genes=self.summary_stats.n_vars,
            n_input_proteins=self.summary_stats.n_proteins,
            n_batch=n_batch,
            n_latent=n_latent,
            n_continuous_cov=self.summary_stats.get("n_extra_continuous_covs",
                                                    0),
            n_cats_per_cov=n_cats_per_cov,
            gene_dispersion=gene_dispersion,
            protein_dispersion=protein_dispersion,
            gene_likelihood=gene_likelihood,
            latent_distribution=latent_distribution,
            protein_batch_mask=batch_mask,
            protein_background_prior_mean=prior_mean,
            protein_background_prior_scale=prior_scale,
            use_size_factor_key=use_size_factor_key,
            library_log_means=library_log_means,
            library_log_vars=library_log_vars,
            **model_kwargs,
        )
        # Human-readable summary of the main hyperparameters.
        self._model_summary_string = (
            "TotalVI Model with the following params: \nn_latent: {}, "
            "gene_dispersion: {}, protein_dispersion: {}, gene_likelihood: {}, latent_distribution: {}"
        ).format(
            n_latent,
            gene_dispersion,
            protein_dispersion,
            gene_likelihood,
            latent_distribution,
        )
        # Must stay last so every local above is bound when it is captured.
        self.init_params_ = self._get_init_params(locals())
예제 #3
0
    def __init__(
        self,
        adata: AnnData,
        n_latent: int = 20,
        gene_dispersion: Literal[
            "gene", "gene-batch", "gene-label", "gene-cell"
        ] = "gene",
        protein_dispersion: Literal[
            "protein", "protein-batch", "protein-label"
        ] = "protein",
        gene_likelihood: Literal["zinb", "nb"] = "nb",
        latent_distribution: Literal["normal", "ln"] = "normal",
        empirical_protein_background_prior: Optional[bool] = None,
        override_missing_proteins: bool = False,
        **model_kwargs,
    ):
        """Create the model and its ``TOTALVAE`` module for ``adata``.

        Parameters
        ----------
        adata
            Passed to the parent initializer, to ``_init_library_size`` and,
            when the empirical prior is enabled, to
            ``_get_totalvi_protein_priors``.
        n_latent
            Forwarded to ``TOTALVAE`` as ``n_latent``.
        gene_dispersion
            Forwarded to ``TOTALVAE`` as ``gene_dispersion``.
        protein_dispersion
            Forwarded to ``TOTALVAE`` as ``protein_dispersion``.
        gene_likelihood
            Forwarded to ``TOTALVAE`` as ``gene_likelihood``.
        latent_distribution
            Forwarded to ``TOTALVAE`` as ``latent_distribution``.
        empirical_protein_background_prior
            Whether to compute the protein background prior from ``adata``.
            When ``None``, it is enabled only if
            ``self.summary_stats["n_proteins"]`` is greater than 10.
        override_missing_proteins
            When ``True``, a ``"totalvi_batch_mask"`` entry in the setup dict
            is ignored: no mask is passed to ``TOTALVAE`` and no message is
            logged.
        **model_kwargs
            Extra keyword arguments forwarded to ``TOTALVAE``.
        """
        super(TOTALVI, self).__init__(adata)

        # NOTE: the last statement hands ``locals()`` to ``_get_init_params``,
        # so the local variable names in this method are deliberately stable.
        if (
            "totalvi_batch_mask" in self.scvi_setup_dict_
            and not override_missing_proteins
        ):
            batch_mask = self.scvi_setup_dict_["totalvi_batch_mask"]
            # Adjacent literals are concatenated verbatim, so each fragment
            # carries its own trailing space (the original ran the last two
            # together as "...=True`,to override").
            info_msg = (
                "Some proteins have all 0 counts in some batches. "
                "These proteins will be treated as missing; however, "
                "this can occur due to experimental design/biology. "
                "Reinitialize the model with `override_missing_proteins=True`, "
                "to override this behavior."
            )
            logger.info(info_msg)
        else:
            batch_mask = None

        # An explicit True/False from the caller wins; otherwise fall back to
        # the protein-count threshold.
        emp_prior = (
            empirical_protein_background_prior
            if empirical_protein_background_prior is not None
            else (self.summary_stats["n_proteins"] > 10)
        )
        if emp_prior:
            prior_mean, prior_scale = _get_totalvi_protein_priors(adata)
        else:
            prior_mean, prior_scale = None, None

        # Category counts are only looked up when extra categoricals were
        # recorded in the setup dict.
        n_cats_per_cov = (
            self.scvi_setup_dict_["extra_categoricals"]["n_cats_per_key"]
            if "extra_categoricals" in self.scvi_setup_dict_
            else None
        )

        n_batch = self.summary_stats["n_batch"]
        library_log_means, library_log_vars = _init_library_size(adata, n_batch)

        self.module = TOTALVAE(
            n_input_genes=self.summary_stats["n_vars"],
            n_input_proteins=self.summary_stats["n_proteins"],
            n_batch=n_batch,
            n_latent=n_latent,
            n_continuous_cov=self.summary_stats["n_continuous_covs"],
            n_cats_per_cov=n_cats_per_cov,
            gene_dispersion=gene_dispersion,
            protein_dispersion=protein_dispersion,
            gene_likelihood=gene_likelihood,
            latent_distribution=latent_distribution,
            protein_batch_mask=batch_mask,
            protein_background_prior_mean=prior_mean,
            protein_background_prior_scale=prior_scale,
            library_log_means=library_log_means,
            library_log_vars=library_log_vars,
            **model_kwargs,
        )
        # Human-readable summary of the main hyperparameters.
        self._model_summary_string = (
            "TotalVI Model with the following params: \nn_latent: {}, "
            "gene_dispersion: {}, protein_dispersion: {}, gene_likelihood: {}, latent_distribution: {}"
        ).format(
            n_latent,
            gene_dispersion,
            protein_dispersion,
            gene_likelihood,
            latent_distribution,
        )
        # Must stay last so every local above is bound when it is captured.
        self.init_params_ = self._get_init_params(locals())