Esempio n. 1
0
 def __init__(
     self,
     st_adata: AnnData,
     cell_type_mapping: np.ndarray,
     decoder_state_dict: OrderedDict,
     px_decoder_state_dict: OrderedDict,
     px_r: np.ndarray,
     n_hidden: int,
     n_latent: int,
     n_layers: int,
     **module_kwargs,
 ):
     """
     Initialise the model and build its ``MRDeconv`` module.

     Parameters
     ----------
     st_adata
         AnnData object passed to the parent constructor. Its ``n_obs`` and
         ``n_vars`` become the module's ``n_spots`` and ``n_genes``.
     cell_type_mapping
         Array of cell type labels. Its first dimension sets ``n_labels`` and
         the array itself is stored as ``self.cell_type_mapping``.
     decoder_state_dict
         Forwarded unchanged to ``MRDeconv``.
     px_decoder_state_dict
         Forwarded unchanged to ``MRDeconv``.
     px_r
         Forwarded unchanged to ``MRDeconv``.
     n_hidden
         Forwarded unchanged to ``MRDeconv``.
     n_latent
         Forwarded unchanged to ``MRDeconv``.
     n_layers
         Forwarded unchanged to ``MRDeconv``.
     **module_kwargs
         Additional keyword arguments forwarded to ``MRDeconv``.
     """
     super(DestVI, self).__init__(st_adata)
     # Sizes are derived from the data; everything else is passed through as given.
     self.module = MRDeconv(
         n_spots=st_adata.n_obs,
         n_labels=cell_type_mapping.shape[0],
         decoder_state_dict=decoder_state_dict,
         px_decoder_state_dict=px_decoder_state_dict,
         px_r=px_r,
         n_genes=st_adata.n_vars,
         n_latent=n_latent,
         n_layers=n_layers,
         n_hidden=n_hidden,
         **module_kwargs,
     )
     self.cell_type_mapping = cell_type_mapping
     self._model_summary_string = "DestVI Model"
     # locals() snapshots every name bound in this frame, so do not introduce
     # extra local variables above this line without checking the helper.
     # NOTE(review): presumably records the constructor arguments - confirm
     # against _get_init_params in the base class.
     self.init_params_ = self._get_init_params(locals())
Esempio n. 2
0
class DestVI(UnsupervisedTrainingMixin, BaseModelClass):
    """
    Multi-resolution deconvolution of Spatial Transcriptomics data (DestVI) [Lopez21]_. Most users will use the alternate constructor (see example).

    Parameters
    ----------
    st_adata
        spatial transcriptomics AnnData object that has been registered via :meth:`~scvi.model.DestVI.setup_anndata`.
    cell_type_mapping
        mapping between numerals and cell type labels
    decoder_state_dict
        state_dict from the decoder of the CondSCVI model
    px_decoder_state_dict
        state_dict from the px_decoder of the CondSCVI model
    px_r
        parameters for the px_r tensor in the CondSCVI model
    n_hidden
        Number of nodes per hidden layer.
    n_latent
        Dimensionality of the latent space.
    n_layers
        Number of hidden layers used for encoder and decoder NNs.
    **module_kwargs
        Keyword args for :class:`~scvi.modules.MRDeconv`

    Examples
    --------
    >>> sc_adata = anndata.read_h5ad(path_to_scRNA_anndata)
    >>> scvi.model.CondSCVI.setup_anndata(sc_adata)
    >>> sc_model = scvi.model.CondSCVI(sc_adata)
    >>> st_adata = anndata.read_h5ad(path_to_ST_anndata)
    >>> DestVI.setup_anndata(st_adata)
    >>> spatial_model = DestVI.from_rna_model(st_adata, sc_model)
    >>> spatial_model.train(max_epochs=2000)
    >>> st_adata.obsm["proportions"] = spatial_model.get_proportions()
    >>> gamma = spatial_model.get_gamma()

    Notes
    -----
    See further usage examples in the following tutorials:

    1. :doc:`/user_guide/notebooks/DestVI_tutorial`
    """

    def __init__(
        self,
        st_adata: AnnData,
        cell_type_mapping: np.ndarray,
        decoder_state_dict: OrderedDict,
        px_decoder_state_dict: OrderedDict,
        px_r: np.ndarray,
        n_hidden: int,
        n_latent: int,
        n_layers: int,
        **module_kwargs,
    ):
        """
        Initialise the model and build its ``MRDeconv`` module.

        Parameters
        ----------
        st_adata
            AnnData object passed to the parent constructor. Its ``n_obs`` and
            ``n_vars`` become the module's ``n_spots`` and ``n_genes``.
        cell_type_mapping
            Array of cell type labels. Its first dimension sets ``n_labels``
            and the array itself is stored as ``self.cell_type_mapping``.
        decoder_state_dict
            Forwarded unchanged to ``MRDeconv``.
        px_decoder_state_dict
            Forwarded unchanged to ``MRDeconv``.
        px_r
            Forwarded unchanged to ``MRDeconv``.
        n_hidden
            Forwarded unchanged to ``MRDeconv``.
        n_latent
            Forwarded unchanged to ``MRDeconv``.
        n_layers
            Forwarded unchanged to ``MRDeconv``.
        **module_kwargs
            Additional keyword arguments forwarded to ``MRDeconv``.
        """
        super(DestVI, self).__init__(st_adata)
        # Sizes are derived from the data; everything else is passed through as given.
        self.module = MRDeconv(
            n_spots=st_adata.n_obs,
            n_labels=cell_type_mapping.shape[0],
            decoder_state_dict=decoder_state_dict,
            px_decoder_state_dict=px_decoder_state_dict,
            px_r=px_r,
            n_genes=st_adata.n_vars,
            n_latent=n_latent,
            n_layers=n_layers,
            n_hidden=n_hidden,
            **module_kwargs,
        )
        self.cell_type_mapping = cell_type_mapping
        self._model_summary_string = "DestVI Model"
        # locals() snapshots every name bound in this frame, so do not
        # introduce extra local variables above this line without checking
        # the helper.
        # NOTE(review): presumably records the constructor arguments - confirm
        # against _get_init_params in the base class.
        self.init_params_ = self._get_init_params(locals())

    @classmethod
    def from_rna_model(
        cls,
        st_adata: AnnData,
        sc_model: CondSCVI,
        vamp_prior_p: int = 50,
        layer: Optional[str] = None,
        **module_kwargs,
    ):
        """
        Alternate constructor that reuses a model pre-trained on an RNA-seq dataset.

        The decoder weights, the ``px_r`` values, the label mapping and the
        layer sizes are all read from ``sc_model``. ``st_adata`` is registered
        with ``cls.setup_anndata`` before the new model is instantiated.

        Parameters
        ----------
        st_adata
            spatial transcriptomics AnnData object; registered here via ``cls.setup_anndata``
        sc_model
            trained CondSCVI model
        vamp_prior_p
            number of mixture components requested from ``sc_model.get_vamp_prior``.
            If `None`, the prior is not computed and ``mean_vprior`` / ``var_vprior`` are passed as `None`.
        layer
            forwarded to ``cls.setup_anndata`` as its ``layer`` argument
        **module_kwargs
            Extra keyword arguments forwarded to the class constructor,
            alongside ``mean_vprior`` and ``var_vprior``.

        Returns
        -------
        A new instance of ``cls``.
        """
        sc_module = sc_model.module

        # Pieces of the single-cell model that the spatial model reuses.
        decoder_weights = sc_module.decoder.state_dict()
        px_decoder_weights = sc_module.px_decoder.state_dict()
        px_r_values = sc_module.px_r.detach().cpu().numpy()
        labels_registry = sc_model.adata_manager.get_state_registry(
            REGISTRY_KEYS.LABELS_KEY
        )
        ct_mapping = labels_registry.categorical_mapping

        # The VampPrior statistics are optional; None disables them.
        mean_vprior, var_vprior = None, None
        if vamp_prior_p is not None:
            mean_vprior, var_vprior = sc_model.get_vamp_prior(
                sc_model.adata, p=vamp_prior_p
            )

        cls.setup_anndata(st_adata, layer=layer)
        return cls(
            st_adata,
            ct_mapping,
            decoder_weights,
            px_decoder_weights,
            px_r_values,
            sc_module.n_hidden,
            sc_module.n_latent,
            sc_module.n_layers,
            mean_vprior=mean_vprior,
            var_vprior=var_vprior,
            **module_kwargs,
        )

    def get_proportions(
        self,
        keep_noise: bool = False,
        indices: Optional[Sequence[int]] = None,
        batch_size: Optional[int] = None,
    ) -> pd.DataFrame:
        """
        Returns the estimated cell type proportion for the spatial data.

        Shape is n_cells x n_labels OR n_cells x (n_labels + 1) if keep_noise.

        Parameters
        ----------
        keep_noise
            whether to account for the noise term as a standalone cell type in the proportion estimate.
        indices
            Indices of cells in adata to use. Only used if amortization. If `None`, all cells are used.
        batch_size
            Minibatch size for data loading into model. Only used if amortization. Defaults to `scvi.settings.batch_size`.

        Returns
        -------
        DataFrame whose columns are the entries of ``self.cell_type_mapping``
        (plus ``"noise_term"`` if ``keep_noise``) and whose rows are indexed by
        ``self.adata.obs.index``, restricted to ``indices`` when the
        proportions are amortized.
        """
        self._check_if_trained()

        column_names = self.cell_type_mapping
        index_names = self.adata.obs.index
        if keep_noise:
            column_names = np.append(column_names, "noise_term")

        if self.module.amortization in ["both", "proportion"]:
            stdl = self._make_data_loader(
                adata=self.adata, indices=indices, batch_size=batch_size
            )
            prop_ = []
            for tensors in stdl:
                generative_inputs = self.module._get_generative_input(tensors, None)
                prop_local = self.module.get_proportions(
                    x=generative_inputs["x"], keep_noise=keep_noise
                )
                prop_.append(prop_local.cpu())
            data = torch.cat(prop_).numpy()
            # Compare against None rather than relying on truthiness: a NumPy
            # index array raises on bool() when it has several elements, and
            # ``np.array([0])`` or ``[]`` would be falsy, leaving the row
            # index out of step with the data.
            if indices is not None:
                index_names = index_names[indices]
        else:
            if indices is not None:
                logger.info(
                    "No amortization for proportions, ignoring indices and returning results for the full data"
                )
            data = self.module.get_proportions(keep_noise=keep_noise)

        return pd.DataFrame(
            data=data,
            columns=column_names,
            index=index_names,
        )

    def get_gamma(
        self,
        indices: Optional[Sequence[int]] = None,
        batch_size: Optional[int] = None,
        return_numpy: bool = False,
    ) -> Union[np.ndarray, Dict[str, pd.DataFrame]]:
        """
        Returns the estimated cell-type specific latent space for the spatial data.

        Parameters
        ----------
        indices
            Indices of cells in adata to use. Only used if amortization. If `None`, all cells are used.
        batch_size
            Minibatch size for data loading into model. Only used if amortization. Defaults to `scvi.settings.batch_size`.
        return_numpy
            if activated, will return a numpy array of shape n_spots x n_latent x n_labels.

        Returns
        -------
        If ``return_numpy`` is set, the array described above. Otherwise a
        dictionary with one entry per cell type in ``self.cell_type_mapping``;
        each value is a DataFrame with one row per spot and one column per
        latent dimension (named ``0 .. n_latent - 1``).
        """
        self._check_if_trained()

        column_names = np.arange(self.module.n_latent)
        index_names = self.adata.obs.index

        if self.module.amortization in ["both", "latent"]:
            stdl = self._make_data_loader(
                adata=self.adata, indices=indices, batch_size=batch_size
            )
            gamma_ = []
            for tensors in stdl:
                generative_inputs = self.module._get_generative_input(tensors, None)
                gamma_local = self.module.get_gamma(x=generative_inputs["x"])
                gamma_.append(gamma_local.cpu())
            # Minibatches are stacked along the last axis, which the
            # transpose below moves to the front.
            data = torch.cat(gamma_, dim=-1).numpy()
            if indices is not None:
                index_names = index_names[indices]
        else:
            if indices is not None:
                # The ignored argument is ``indices``; this method has no
                # ``adata`` parameter.
                logger.info(
                    "No amortization for latent values, ignoring indices and returning results for the full data"
                )
            data = self.module.get_gamma()

        # Reorder axes so the former last axis comes first; the result is
        # indexed as [spot, latent dimension, cell type] below.
        data = np.transpose(data, (2, 0, 1))
        if return_numpy:
            return data

        return {
            ct: pd.DataFrame(
                data=data[:, :, i], columns=column_names, index=index_names
            )
            for i, ct in enumerate(self.cell_type_mapping)
        }

    def get_scale_for_ct(
        self,
        label: str,
        indices: Optional[Sequence[int]] = None,
        batch_size: Optional[int] = None,
    ) -> pd.DataFrame:
        r"""
        Return the scaled parameter of the NB for every spot, for one queried cell type.

        Parameters
        ----------
        label
            cell type of interest; must be an entry of ``self.cell_type_mapping``
        indices
            Indices of cells in self.adata to use. If `None`, all cells are used.
        batch_size
            Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.

        Returns
        -------
        Pandas dataframe of gene_expression, with one row per selected spot
        and one column per entry of ``self.adata.var.index``.

        Raises
        ------
        ValueError
            If ``label`` is not found in ``self.cell_type_mapping``.
        """
        self._check_if_trained()

        if label not in self.cell_type_mapping:
            raise ValueError("Unknown cell type")
        # Position of the first entry in the mapping equal to the label.
        matches = np.where(label == self.cell_type_mapping)
        y = matches[0][0]

        loader = self._make_data_loader(
            self.adata, indices=indices, batch_size=batch_size
        )
        per_batch = []
        for batch in loader:
            inputs = self.module._get_generative_input(batch, None)
            expression = self.module.get_ct_specific_expression(
                inputs["x"], inputs["ind_x"], y
            )
            per_batch.append(expression.cpu())

        row_labels = self.adata.obs.index
        if indices is not None:
            row_labels = row_labels[indices]
        return pd.DataFrame(
            data=torch.cat(per_batch).numpy(),
            columns=self.adata.var.index,
            index=row_labels,
        )

    def train(
        self,
        max_epochs: int = 400,
        lr: float = 0.005,
        use_gpu: Optional[Union[str, int, bool]] = None,
        train_size: float = 1.0,
        validation_size: Optional[float] = None,
        batch_size: int = 128,
        n_epochs_kl_warmup: int = 50,
        plan_kwargs: Optional[dict] = None,
        **kwargs,
    ):
        """
        Trains the model using MAP inference.

        Parameters
        ----------
        max_epochs
            Number of epochs to train for
        lr
            Learning rate for optimization.
        use_gpu
            Use default GPU if available (if None or True), or index of GPU to use (if int),
            or name of GPU (if str, e.g., `'cuda:0'`), or use CPU (if False).
        train_size
            Size of training set in the range [0.0, 1.0].
        validation_size
            Size of the validation set. If `None`, defaults to 1 - `train_size`. If
            `train_size + validation_size < 1`, the remaining cells belong to a test set.
        batch_size
            Minibatch size to use during training.
        n_epochs_kl_warmup
            number of epochs needed to reach unit kl weight in the elbo
        plan_kwargs
            Keyword args for :class:`~scvi.train.TrainingPlan`. Keyword arguments passed to
            `train()` will overwrite values present in `plan_kwargs`, when appropriate.
            The dictionary passed in is not modified.
        **kwargs
            Other keyword args for :class:`~scvi.train.Trainer`.
        """
        update_dict = {
            "lr": lr,
            "n_epochs_kl_warmup": n_epochs_kl_warmup,
        }
        # Merge into a fresh dict so the caller's ``plan_kwargs`` object is
        # left untouched; entries from ``update_dict`` take precedence.
        if plan_kwargs is not None:
            plan_kwargs = {**plan_kwargs, **update_dict}
        else:
            plan_kwargs = update_dict
        super().train(
            max_epochs=max_epochs,
            use_gpu=use_gpu,
            train_size=train_size,
            validation_size=validation_size,
            batch_size=batch_size,
            plan_kwargs=plan_kwargs,
            **kwargs,
        )

    @classmethod
    @setup_anndata_dsp.dedent
    def setup_anndata(
        cls,
        adata: AnnData,
        layer: Optional[str] = None,
        **kwargs,
    ):
        """
        %(summary)s.

        Parameters
        ----------
        %(param_layer)s
        """
        # Keep this as the first statement: locals() is expanded into keyword
        # arguments, and at this point only the call arguments are bound.
        setup_method_args = cls._get_setup_method_args(**locals())

        # add index for each cell (provided to pyro plate for correct minibatching)
        adata.obs["_indices"] = np.arange(adata.n_obs).astype("int64")

        # Two fields are registered: the count layer and the per-cell index
        # column written to ``adata.obs`` just above.
        manager = AnnDataManager(
            fields=[
                LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
                NumericalObsField(REGISTRY_KEYS.INDICES_KEY, "_indices"),
            ],
            setup_method_args=setup_method_args,
        )
        manager.register_fields(adata, **kwargs)
        cls.register_manager(manager)