예제 #1
0
    def __init__(
        self,
        adata: AnnData,
        cell_type_markers: pd.DataFrame,
        size_factor_key: str,
        **model_kwargs,
    ):
        """Set up the model and build its ``CellAssignModule``.

        Parameters
        ----------
        adata
            AnnData object handed to the parent constructor. Its ``var_names``
            select and order the rows of ``cell_type_markers``.
        cell_type_markers
            Marker table whose index must be addressable by
            ``adata.var_names``. Its number of columns is reported as
            ``n_labels`` in the model summary string.
        size_factor_key
            Passed to ``register_tensor_from_anndata`` together with ``"obs"``
            and the registry name ``"_size_factor"`` (presumably the ``obs``
            column holding per-cell size factors -- confirm against callers).
        **model_kwargs
            Forwarded unchanged to ``CellAssignModule``.

        Raises
        ------
        KeyError
            If ``cell_type_markers`` cannot be indexed by ``adata.var_names``.
        """
        try:
            cell_type_markers = cell_type_markers.loc[adata.var_names]
        except KeyError as err:
            # Chain explicitly so the original pandas error (which names the
            # offending labels) is shown as the direct cause in the traceback.
            raise KeyError(
                "Anndata and cell type markers do not contain the same genes."
            ) from err
        super().__init__(adata)

        register_tensor_from_anndata(adata, "_size_factor", "obs", size_factor_key)

        self.n_genes = self.summary_stats["n_vars"]
        self.cell_type_markers = cell_type_markers
        rho = torch.Tensor(cell_type_markers.to_numpy())
        # Only read the per-key category counts when extra categorical
        # covariates were registered; otherwise pass None to the module.
        n_cats_per_cov = (
            self.scvi_setup_dict_["extra_categoricals"]["n_cats_per_key"]
            if "extra_categoricals" in self.scvi_setup_dict_
            else None
        )

        x = scvi.data.get_from_registry(adata, _CONSTANTS.X_KEY)
        # Per-column mean of x, flattened to 1-D, then standardised across
        # columns to give the initial value passed as b_g_0.
        col_means = np.asarray(np.mean(x, 0)).ravel()  # (g)
        col_means_mu, col_means_std = np.mean(col_means), np.std(col_means)
        # NOTE(review): if every column mean is identical, col_means_std is 0
        # and this division yields non-finite values.
        col_means_normalized = torch.Tensor((col_means - col_means_mu) / col_means_std)

        # compute basis means for phi - shape (B)
        # B is a name defined outside this method; the points are evenly
        # spaced between the smallest and largest entry of x.
        basis_means = np.linspace(np.min(x), np.max(x), B)  # (B)

        self.module = CellAssignModule(
            n_genes=self.n_genes,
            rho=rho,
            basis_means=basis_means,
            b_g_0=col_means_normalized,
            n_batch=self.summary_stats["n_batch"],
            n_cats_per_cov=n_cats_per_cov,
            n_continuous_cov=self.summary_stats["n_continuous_covs"],
            **model_kwargs,
        )
        self._model_summary_string = (
            "CellAssign Model with params: \nn_genes: {}, n_labels: {}"
        ).format(
            self.n_genes,
            rho.shape[1],
        )
        # locals() is captured here, so the local names above are part of what
        # gets recorded; cell_type_markers is the reindexed frame by now.
        self.init_params_ = self._get_init_params(locals())
예제 #2
0
    def __init__(
        self,
        adata: AnnData,
        cell_type_markers: pd.DataFrame,
        **model_kwargs,
    ):
        """Set up the model and build its ``CellAssignModule``.

        Parameters
        ----------
        adata
            AnnData object handed to the parent constructor. Its ``var_names``
            select and order the rows of ``cell_type_markers``.
        cell_type_markers
            Marker table whose index must be addressable by
            ``adata.var_names``. Its number of columns is reported as
            ``n_labels`` in the model summary string.
        **model_kwargs
            Forwarded unchanged to ``CellAssignModule``.

        Raises
        ------
        KeyError
            If ``cell_type_markers`` cannot be indexed by ``adata.var_names``.
        """
        try:
            cell_type_markers = cell_type_markers.loc[adata.var_names]
        except KeyError as err:
            # Chain explicitly so the original pandas error (which names the
            # offending labels) is shown as the direct cause in the traceback.
            raise KeyError(
                "Anndata and cell type markers do not contain the same genes."
            ) from err
        super().__init__(adata)

        self.n_genes = self.summary_stats.n_vars
        self.cell_type_markers = cell_type_markers
        rho = torch.Tensor(cell_type_markers.to_numpy())
        # Only read the per-key category counts when categorical covariates
        # are present in the data registry; otherwise pass None to the module.
        n_cats_per_cov = (
            self.adata_manager.get_state_registry(
                REGISTRY_KEYS.CAT_COVS_KEY
            ).n_cats_per_key
            if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry
            else None
        )

        # adata is rebound to whatever _validate_anndata returns before the
        # data matrix is read from the registry.
        adata = self._validate_anndata(adata)
        x = self.get_from_registry(adata, REGISTRY_KEYS.X_KEY)
        # Per-column mean of x, flattened to 1-D, then standardised across
        # columns to give the initial value passed as b_g_0.
        col_means = np.asarray(np.mean(x, 0)).ravel()  # (g)
        col_means_mu, col_means_std = np.mean(col_means), np.std(col_means)
        # NOTE(review): if every column mean is identical, col_means_std is 0
        # and this division yields non-finite values.
        col_means_normalized = torch.Tensor(
            (col_means - col_means_mu) / col_means_std
        )

        # compute basis means for phi - shape (B)
        # B is a name defined outside this method; the points are evenly
        # spaced between the smallest and largest entry of x.
        basis_means = np.linspace(np.min(x), np.max(x), B)  # (B)

        self.module = CellAssignModule(
            n_genes=self.n_genes,
            rho=rho,
            basis_means=basis_means,
            b_g_0=col_means_normalized,
            n_batch=self.summary_stats.n_batch,
            n_cats_per_cov=n_cats_per_cov,
            # Falls back to 0 when the summary stats carry no such entry.
            n_continuous_cov=self.summary_stats.get("n_extra_continuous_covs", 0),
            **model_kwargs,
        )
        self._model_summary_string = (
            "CellAssign Model with params: \nn_genes: {}, n_labels: {}"
        ).format(
            self.n_genes,
            rho.shape[1],
        )
        # locals() is captured here, so the local names above are part of what
        # gets recorded; adata and cell_type_markers are the rebound values.
        self.init_params_ = self._get_init_params(locals())