コード例 #1
0
    def __init__(
        self,
        adata: AnnData,
        cell_type_markers: pd.DataFrame,
        size_factor_key: str,
        **model_kwargs,
    ):
        """
        Build the CellAssign module from an AnnData object and a marker table.

        Parameters
        ----------
        adata
            AnnData object handed to the parent constructor. Its ``var_names``
            select (and order) the rows of ``cell_type_markers``.
        cell_type_markers
            DataFrame that must be indexable by every entry of ``adata.var_names``.
            Its values become the ``rho`` tensor given to the module, and its
            second dimension is reported as ``n_labels`` in the summary string.
        size_factor_key
            Passed to ``register_tensor_from_anndata`` together with ``"obs"``
            and the registry name ``"_size_factor"``.
        **model_kwargs
            Forwarded unchanged to ``CellAssignModule``.

        Raises
        ------
        KeyError
            If ``cell_type_markers`` cannot be indexed by ``adata.var_names``.
        """
        # NOTE: `locals()` is handed to `_get_init_params` at the end of this
        # method, so the local variable names below are deliberately stable.
        try:
            cell_type_markers = cell_type_markers.loc[adata.var_names]
        except KeyError as err:
            # Chain explicitly so the labels pandas complained about stay
            # visible as the direct cause of the friendlier error.
            raise KeyError(
                "Anndata and cell type markers do not contain the same genes."
            ) from err
        super().__init__(adata)

        register_tensor_from_anndata(adata, "_size_factor", "obs", size_factor_key)

        self.n_genes = self.summary_stats["n_vars"]
        self.cell_type_markers = cell_type_markers
        rho = torch.Tensor(cell_type_markers.to_numpy())
        # Category counts are only available when extra categorical covariates
        # were recorded in the setup dictionary.
        n_cats_per_cov = (
            self.scvi_setup_dict_["extra_categoricals"]["n_cats_per_key"]
            if "extra_categoricals" in self.scvi_setup_dict_
            else None
        )

        # Standardise the per-column means of X; the result is passed to the
        # module as `b_g_0`.
        x = scvi.data.get_from_registry(adata, _CONSTANTS.X_KEY)
        col_means = np.asarray(np.mean(x, 0)).ravel()  # (g)
        col_means_mu, col_means_std = np.mean(col_means), np.std(col_means)
        col_means_normalized = torch.Tensor((col_means - col_means_mu) / col_means_std)

        # compute basis means for phi - shape (B)
        # B evenly spaced points between the smallest and largest value of X.
        basis_means = np.linspace(np.min(x), np.max(x), B)  # (B)

        self.module = CellAssignModule(
            n_genes=self.n_genes,
            rho=rho,
            basis_means=basis_means,
            b_g_0=col_means_normalized,
            n_batch=self.summary_stats["n_batch"],
            n_cats_per_cov=n_cats_per_cov,
            n_continuous_cov=self.summary_stats["n_continuous_covs"],
            **model_kwargs,
        )
        self._model_summary_string = (
            "CellAssign Model with params: \nn_genes: {}, n_labels: {}"
        ).format(
            self.n_genes,
            rho.shape[1],
        )
        self.init_params_ = self._get_init_params(locals())
コード例 #2
0
    def __init__(
        self,
        adata: AnnData,
        cell_type_markers: pd.DataFrame,
        **model_kwargs,
    ):
        """
        Build the CellAssign module from an AnnData object and a marker table.

        Parameters
        ----------
        adata
            AnnData object handed to the parent constructor. Its ``var_names``
            select (and order) the rows of ``cell_type_markers``.
        cell_type_markers
            DataFrame that must be indexable by every entry of ``adata.var_names``.
            Its values become the ``rho`` tensor given to the module, and its
            second dimension is reported as ``n_labels`` in the summary string.
        **model_kwargs
            Forwarded unchanged to ``CellAssignModule``.

        Raises
        ------
        KeyError
            If ``cell_type_markers`` cannot be indexed by ``adata.var_names``.
        """
        # NOTE: `locals()` is handed to `_get_init_params` at the end of this
        # method, so the local variable names below are deliberately stable.
        try:
            cell_type_markers = cell_type_markers.loc[adata.var_names]
        except KeyError as err:
            # Chain explicitly so the labels pandas complained about stay
            # visible as the direct cause of the friendlier error.
            raise KeyError(
                "Anndata and cell type markers do not contain the same genes."
            ) from err
        super().__init__(adata)

        self.n_genes = self.summary_stats.n_vars
        self.cell_type_markers = cell_type_markers
        rho = torch.Tensor(cell_type_markers.to_numpy())
        # Category counts are only looked up when the categorical-covariate
        # key is present in the data registry.
        n_cats_per_cov = (
            self.adata_manager.get_state_registry(
                REGISTRY_KEYS.CAT_COVS_KEY
            ).n_cats_per_key
            if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry
            else None
        )

        # Standardise the per-column means of X; the result is passed to the
        # module as `b_g_0`.
        adata = self._validate_anndata(adata)
        x = self.get_from_registry(adata, REGISTRY_KEYS.X_KEY)
        col_means = np.asarray(np.mean(x, 0)).ravel()  # (g)
        col_means_mu, col_means_std = np.mean(col_means), np.std(col_means)
        col_means_normalized = torch.Tensor(
            (col_means - col_means_mu) / col_means_std)

        # compute basis means for phi - shape (B)
        # B evenly spaced points between the smallest and largest value of X.
        basis_means = np.linspace(np.min(x), np.max(x), B)  # (B)

        self.module = CellAssignModule(
            n_genes=self.n_genes,
            rho=rho,
            basis_means=basis_means,
            b_g_0=col_means_normalized,
            n_batch=self.summary_stats.n_batch,
            n_cats_per_cov=n_cats_per_cov,
            # Falls back to 0 when no continuous-covariate count is recorded.
            n_continuous_cov=self.summary_stats.get("n_extra_continuous_covs",
                                                    0),
            **model_kwargs,
        )
        self._model_summary_string = (
            "CellAssign Model with params: \nn_genes: {}, n_labels: {}"
        ).format(
            self.n_genes,
            rho.shape[1],
        )
        self.init_params_ = self._get_init_params(locals())
コード例 #3
0
ファイル: _model.py プロジェクト: ritika-giri/scvi-tools
class CellAssign(UnsupervisedTrainingMixin, BaseModelClass):
    """
    Reimplementation of CellAssign for reference-based annotation [Zhang19]_.

    Parameters
    ----------
    adata
        single-cell AnnData object that has been registered via :func:`~scvi.data.setup_anndata`.
        The object should be subset to contain the same genes as the cell type marker dataframe.
    cell_type_markers
        Binary marker gene DataFrame of genes by cell types. Gene names corresponding to `adata.var_names`
        should be in DataFrame index, and cell type labels should be the columns.
    size_factor_key
        Key in `adata.obs` with continuous valued size factors.
    **model_kwargs
        Keyword args for :class:`~scvi.external.cellassign.CellAssignModule`

    Examples
    --------
    >>> adata = scvi.data.read_h5ad(path_to_anndata)
    >>> marker_gene_mat = pd.read_csv(path_to_marker_gene_csv)
    >>> bdata = adata[:, adata.var.index.isin(marker_gene_mat.index)].copy()
    >>> scvi.data.setup_anndata(bdata)
    >>> model = CellAssign(bdata, marker_gene_mat, size_factor_key='S')
    >>> model.train()
    >>> predictions = model.predict()
    """

    def __init__(
        self,
        adata: AnnData,
        cell_type_markers: pd.DataFrame,
        size_factor_key: str,
        **model_kwargs,
    ):
        # NOTE: `locals()` is handed to `_get_init_params` at the end of this
        # method, so the local variable names below are deliberately stable.
        try:
            cell_type_markers = cell_type_markers.loc[adata.var_names]
        except KeyError as err:
            # Chain explicitly so the labels pandas complained about stay
            # visible as the direct cause of the friendlier error.
            raise KeyError(
                "Anndata and cell type markers do not contain the same genes."
            ) from err
        super().__init__(adata)

        register_tensor_from_anndata(adata, "_size_factor", "obs",
                                     size_factor_key)

        self.n_genes = self.summary_stats["n_vars"]
        self.cell_type_markers = cell_type_markers
        rho = torch.Tensor(cell_type_markers.to_numpy())
        # Category counts are only available when extra categorical covariates
        # were recorded in the setup dictionary.
        n_cats_per_cov = (
            self.scvi_setup_dict_["extra_categoricals"]["n_cats_per_key"]
            if "extra_categoricals" in self.scvi_setup_dict_ else None)

        # Standardise the per-column means of X; the result is passed to the
        # module as `b_g_0`.
        x = scvi.data.get_from_registry(adata, _CONSTANTS.X_KEY)
        col_means = np.asarray(np.mean(x, 0)).ravel()  # (g)
        col_means_mu, col_means_std = np.mean(col_means), np.std(col_means)
        col_means_normalized = torch.Tensor(
            (col_means - col_means_mu) / col_means_std)

        # compute basis means for phi - shape (B)
        # B evenly spaced points between the smallest and largest value of X.
        basis_means = np.linspace(np.min(x), np.max(x), B)  # (B)

        self.module = CellAssignModule(
            n_genes=self.n_genes,
            rho=rho,
            basis_means=basis_means,
            b_g_0=col_means_normalized,
            n_batch=self.summary_stats["n_batch"],
            n_cats_per_cov=n_cats_per_cov,
            n_continuous_cov=self.summary_stats["n_continuous_covs"],
            **model_kwargs,
        )
        self._model_summary_string = (
            "CellAssign Model with params: \nn_genes: {}, n_labels: {}"
        ).format(
            self.n_genes,
            rho.shape[1],
        )
        self.init_params_ = self._get_init_params(locals())

    @torch.no_grad()
    def predict(self) -> pd.DataFrame:
        """
        Predict soft cell type assignment probability for each cell.

        Returns
        -------
        DataFrame with one column per column of ``cell_type_markers``; rows are
        the concatenated ``gamma`` outputs of the module over the data loader
        built for the AnnData returned by ``self._validate_anndata(None)``.
        """
        adata = self._validate_anndata(None)
        scdl = self._make_data_loader(adata=adata)
        predictions = []
        for tensors in scdl:
            generative_inputs = self.module._get_generative_input(
                tensors, None)
            outputs = self.module.generative(**generative_inputs)
            # Move each batch to CPU so the batches can be concatenated and
            # converted to a NumPy array afterwards.
            predictions.append(outputs["gamma"].cpu())
        return pd.DataFrame(np.array(torch.cat(predictions)),
                            columns=self.cell_type_markers.columns)

    def train(
        self,
        max_epochs: Optional[int] = 400,
        lr: float = 3e-3,
        use_gpu: Optional[Union[str, int, bool]] = None,
        train_size: float = 0.9,
        validation_size: Optional[float] = None,
        batch_size: int = 1024,
        plan_kwargs: Optional[dict] = None,
        early_stopping: bool = True,
        early_stopping_patience: int = 15,
        early_stopping_min_delta: float = 0.0,
        **kwargs,
    ):
        """
        Trains the model.

        Parameters
        ----------
        max_epochs
            Number of epochs to train for. If `None`, defaults to
            `min(round((20000 / n_cells) * 400), 400)`.
        lr
            Learning rate for optimization.
        use_gpu
            Use default GPU if available (if None or True), or index of GPU to use (if int),
            or name of GPU (if str), or use CPU (if False).
        train_size
            Size of training set in the range [0.0, 1.0].
        validation_size
            Size of the validation set. If `None`, defaults to 1 - `train_size`. If
            `train_size + validation_size < 1`, the remaining cells belong to a test set.
        batch_size
            Minibatch size to use during training.
        plan_kwargs
            Keyword args for :class:`~scvi.train.TrainingPlan`. Keyword arguments passed to
            `train()` (`lr`, plus a fixed `weight_decay` of 1e-10) will overwrite values
            present in `plan_kwargs`. The dictionary passed in is not modified.
        early_stopping
            Adds callback for early stopping on the `elbo_validation` metric and forces
            validation to run every epoch.
        early_stopping_patience
            Number of times early stopping metric can not improve over early_stopping_min_delta
        early_stopping_min_delta
            Threshold for counting an epoch towards patience
        **kwargs
            Other keyword args for :class:`~scvi.train.Trainer`. A `callbacks` entry, if given,
            is copied and extended with the callbacks added here.

        Returns
        -------
        Whatever the `TrainRunner` call returns.
        """
        # Work on a copy so the caller's dictionary is never mutated; values
        # set here take precedence over user-supplied ones.
        plan_kwargs = dict(plan_kwargs) if plan_kwargs is not None else {}
        plan_kwargs.update(lr=lr, weight_decay=1e-10)

        # Likewise copy the caller's callback list instead of extending it in
        # place. ClampCallback is always added; EarlyStopping only on request.
        callbacks = list(kwargs.get("callbacks") or [])
        callbacks.append(ClampCallback())
        if early_stopping:
            callbacks.append(
                EarlyStopping(
                    monitor="elbo_validation",
                    min_delta=early_stopping_min_delta,
                    patience=early_stopping_patience,
                    mode="min",
                )
            )
            # The monitored metric must be produced every epoch.
            kwargs["check_val_every_n_epoch"] = 1
        kwargs["callbacks"] = callbacks

        if max_epochs is None:
            n_cells = self.adata.n_obs
            max_epochs = np.min([round((20000 / n_cells) * 400), 400])

        data_splitter = DataSplitter(
            self.adata,
            train_size=train_size,
            validation_size=validation_size,
            batch_size=batch_size,
            use_gpu=use_gpu,
        )
        training_plan = TrainingPlan(self.module, len(data_splitter.train_idx),
                                     **plan_kwargs)
        runner = TrainRunner(
            self,
            training_plan=training_plan,
            data_splitter=data_splitter,
            max_epochs=max_epochs,
            use_gpu=use_gpu,
            **kwargs,
        )
        return runner()