Exemplo n.º 1
0
def test_sampling_zl(save_path):
    """Smoke-test classifier training on samples drawn from a fitted VAE.

    A ``VAE`` is trained on the cortex dataset for two epochs, then a
    ``Classifier`` with ``n_latent + 1`` inputs is trained for two epochs via
    ``ClassifierTrainer`` with ``sampling_zl=True`` and the VAE as the
    sampling model.  Nothing is asserted: the test passes when the whole
    pipeline, including the final accuracy computation, runs without raising.
    """
    adata = scvi.data.cortex(save_path=save_path)
    scvi.data.setup_anndata(adata, labels_key="cell_type")

    def stat(name):
        # Read on demand so each lookup happens exactly where it is needed.
        return adata.uns["_scvi"]["summary_stats"][name]

    # Stage 1: unsupervised model that later serves as the sampling model.
    vae = VAE(stat("n_genes"), stat("n_batch"))
    vae_trainer = UnsupervisedTrainer(
        vae, adata, train_size=0.5, use_cuda=use_cuda
    )
    vae_trainer.train(n_epochs=2)

    # Stage 2: classifier over the sampled representation.
    # NOTE(review): the extra input unit is presumably the library size `l`
    # that accompanies `z` when sampling_zl=True -- confirm.
    classifier = Classifier(vae.n_latent + 1, n_labels=stat("n_labels"))
    cls_trainer = ClassifierTrainer(
        classifier, adata, sampling_model=vae, sampling_zl=True
    )
    cls_trainer.train(n_epochs=2)
    cls_trainer.test_set.accuracy()
Exemplo n.º 2
0
class SCANVI(RNASeqMixin, VAEMixin, BaseModelClass):
    """
    Single-cell annotation using variational inference [Xu19]_.

    Inspired from M1 + M2 model, as described in (https://arxiv.org/pdf/1406.5298.pdf).

    Two modules are held by an instance: a base ``VAE`` (either freshly built or
    taken from ``pretrained_model``) that is trained without labels, and a
    ``SCANVAE`` that receives the base weights and is then trained with labels.

    Parameters
    ----------
    adata
        AnnData object that has been registered via :func:`~scvi.data.setup_anndata`.
    unlabeled_category
        Value used for unlabeled cells in `labels_key` used to setup AnnData with scvi.
    pretrained_model
        Instance of SCVI model that has already been trained. When given, its
        ``model`` attribute is used as the base model and the unsupervised
        pre-training step of :meth:`train` is skipped.
    n_hidden
        Number of nodes per hidden layer.
    n_latent
        Dimensionality of the latent space.
    n_layers
        Number of hidden layers used for encoder and decoder NNs.
    dropout_rate
        Dropout rate for neural networks.
    dispersion
        One of the following:

        * ``'gene'`` - dispersion parameter of NB is constant per gene across cells
        * ``'gene-batch'`` - dispersion can differ between different batches
        * ``'gene-label'`` - dispersion can differ between different labels
        * ``'gene-cell'`` - dispersion can differ for every gene in every cell
    gene_likelihood
        One of:

        * ``'nb'`` - Negative binomial distribution
        * ``'zinb'`` - Zero-inflated negative binomial distribution
        * ``'poisson'`` - Poisson distribution
    use_cuda
        Use the GPU or not.
    **model_kwargs
        Keyword args for :class:`~scvi.core.modules.VAE` and :class:`~scvi.core.modules.SCANVAE`

    Raises
    ------
    ValueError
        If ``pretrained_model`` is given but its ``is_trained`` attribute is ``False``.

    Examples
    --------
    >>> adata = anndata.read_h5ad(path_to_anndata)
    >>> scvi.data.setup_anndata(adata, batch_key="batch", labels_key="labels")
    >>> vae = scvi.model.SCANVI(adata, "Unknown")
    >>> vae.train()
    >>> adata.obsm["X_scVI"] = vae.get_latent_representation()
    >>> adata.obs["pred_label"] = vae.predict()
    """
    def __init__(
        self,
        adata: AnnData,
        unlabeled_category: Union[str, int, float],
        pretrained_model: Optional[SCVI] = None,
        n_hidden: int = 128,
        n_latent: int = 10,
        n_layers: int = 1,
        dropout_rate: float = 0.1,
        dispersion: Literal["gene", "gene-batch", "gene-label",
                            "gene-cell"] = "gene",
        gene_likelihood: Literal["zinb", "nb", "poisson"] = "zinb",
        use_cuda: bool = True,
        **model_kwargs,
    ):
        super(SCANVI, self).__init__(adata, use_cuda=use_cuda)
        self.unlabeled_category = unlabeled_category

        # Base (unsupervised) module: reuse the pretrained one when supplied,
        # otherwise build a fresh VAE that train() will fit first.
        if pretrained_model is not None:
            if pretrained_model.is_trained is False:
                raise ValueError("pretrained model has not been trained")
            self._base_model = pretrained_model.model
            self._is_trained_base = True
        else:
            self._base_model = VAE(
                n_input=self.summary_stats["n_vars"],
                n_batch=self.summary_stats["n_batch"],
                n_hidden=n_hidden,
                n_latent=n_latent,
                n_layers=n_layers,
                dropout_rate=dropout_rate,
                dispersion=dispersion,
                gene_likelihood=gene_likelihood,
                **model_kwargs,
            )
            self._is_trained_base = False
        # Semisupervised module. It is always built from the hyperparameters
        # passed to this constructor; train() later copies the base weights
        # into it with load_state_dict(strict=False).
        # NOTE(review): with a pretrained_model these hyperparameters presumably
        # have to match the pretrained architecture -- confirm.
        self.model = SCANVAE(
            n_input=self.summary_stats["n_vars"],
            n_batch=self.summary_stats["n_batch"],
            n_labels=self.summary_stats["n_labels"],
            n_hidden=n_hidden,
            n_latent=n_latent,
            n_layers=n_layers,
            dropout_rate=dropout_rate,
            dispersion=dispersion,
            gene_likelihood=gene_likelihood,
            **model_kwargs,
        )

        # get indices for labeled and unlabeled cells
        key = self.scvi_setup_dict_["data_registry"][
            _CONSTANTS.LABELS_KEY]["attr_key"]
        self._label_mapping = self.scvi_setup_dict_["categorical_mappings"][
            key]["mapping"]
        original_key = self.scvi_setup_dict_["categorical_mappings"][key][
            "original_key"]
        # Compare against the labels as stored under the original obs column.
        labels = np.asarray(self.adata.obs[original_key]).ravel()
        # Position in the mapping -> label; used by predict() to decode outputs.
        self._code_to_label = {i: l for i, l in enumerate(self._label_mapping)}
        self._unlabeled_indices = np.argwhere(
            labels == self.unlabeled_category).ravel()
        self._labeled_indices = np.argwhere(
            labels != self.unlabeled_category).ravel()
        # Filled in by train().
        self.unsupervised_history_ = None
        self.semisupervised_history_ = None

        self._model_summary_string = (
            "ScanVI Model with params: \nunlabeled_category: {}, n_hidden: {}, n_latent: {}"
            ", n_layers: {}, dropout_rate: {}, dispersion: {}, gene_likelihood: {}"
        ).format(
            unlabeled_category,
            n_hidden,
            n_latent,
            n_layers,
            dropout_rate,
            dispersion,
            gene_likelihood,
        )
        # NOTE: locals() is captured here, so every local name bound above
        # (not only the constructor arguments) is handed to _get_init_params.
        self.init_params_ = self._get_init_params(locals())

    @property
    def _trainer_class(self):
        """Trainer class associated with this model."""
        return SemiSupervisedTrainer

    @property
    def _scvi_dl_class(self):
        """Data loader class associated with this model."""
        return AnnotationDataLoader

    @property
    def history(self):
        """Returns computed metrics during training."""
        return {
            "unsupervised_trainer_history": self.unsupervised_history_,
            "semisupervised_trainer_history": self.semisupervised_history_,
        }

    def train(
        self,
        n_epochs_unsupervised: Optional[int] = None,
        n_epochs_semisupervised: Optional[int] = None,
        train_size: float = 0.9,
        test_size: Optional[float] = None,
        lr: float = 1e-3,
        n_epochs_kl_warmup: int = 400,
        n_iter_kl_warmup: Optional[int] = None,
        frequency: Optional[int] = None,
        unsupervised_trainer_kwargs: dict = {},
        semisupervised_trainer_kwargs: dict = {},
        unsupervised_train_kwargs: dict = {},
        semisupervised_train_kwargs: dict = {},
    ):
        """
        Train the model.

        The base model is trained first with an unsupervised trainer (skipped
        when it is already marked as trained), its weights are copied into the
        semisupervised model, and the latter is then trained.

        ``train_size``, ``test_size``, ``lr``, ``n_epochs_kl_warmup`` and
        ``n_iter_kl_warmup`` are only forwarded to the unsupervised phase; use
        the ``semisupervised_*_kwargs`` dictionaries to configure the second phase.

        Parameters
        ----------
        n_epochs_unsupervised
            Number of passes through the dataset for unsupervised pre-training.
            If `None`, ``min(round(20000 / n_cells * 400), 400)`` is used.
        n_epochs_semisupervised
            Number of passes through the dataset for semisupervised training.
            If `None`, a third of `n_epochs_unsupervised` (rounded), clipped to [2, 10].
        train_size
            Size of training set in the range [0.0, 1.0].
        test_size
            Size of the test set. If `None`, defaults to 1 - `train_size`. If
            `train_size + test_size < 1`, the remaining cells belong to a validation set.
        lr
            Learning rate for optimization.
        n_epochs_kl_warmup
            Number of passes through dataset for scaling term on KL divergence to go from 0 to 1.
        n_iter_kl_warmup
            Number of minibatches for scaling term on KL divergence to go from 0 to 1.
            To use, set to not `None` and set `n_epochs_kl_warmup` to `None`.
        frequency
            Frequency with which metrics are computed on the data for train/test/val sets for both
            the unsupervised and semisupervised trainers. If you'd like a different frequency for
            the semisupervised trainer, set frequency in semisupervised_trainer_kwargs.
        unsupervised_trainer_kwargs
            Other keyword args for :class:`~scvi.core.trainers.UnsupervisedTrainer`.
        semisupervised_trainer_kwargs
            Other keyword args for :class:`~scvi.core.trainers.SemiSupervisedTrainer`.
        unsupervised_train_kwargs
            Keyword args for the train method of :class:`~scvi.core.trainers.UnsupervisedTrainer`.
        semisupervised_train_kwargs
            Keyword args for the train method of :class:`~scvi.core.trainers.SemiSupervisedTrainer`.
        """
        # Copy so that neither the shared default dicts nor the caller's
        # dictionaries are mutated below.
        unsupervised_trainer_kwargs = dict(unsupervised_trainer_kwargs)
        semisupervised_trainer_kwargs = dict(semisupervised_trainer_kwargs)
        unsupervised_train_kwargs = dict(unsupervised_train_kwargs)
        semisupervised_train_kwargs = dict(semisupervised_train_kwargs)

        if n_epochs_unsupervised is None:
            n_epochs_unsupervised = np.min(
                [round((20000 / self.adata.shape[0]) * 400), 400])
        if n_epochs_semisupervised is None:
            n_epochs_semisupervised = int(
                np.min([10,
                        np.max([2, round(n_epochs_unsupervised / 3.0)])]))
        # Both messages are logged even when the unsupervised phase is skipped.
        logger.info("Training Unsupervised Trainer for {} epochs.".format(
            n_epochs_unsupervised))
        logger.info("Training SemiSupervised Trainer for {} epochs.".format(
            n_epochs_semisupervised))

        # Phase 1: unsupervised pre-training of the base model, run at most once.
        if self._is_trained_base is not True:
            self._unsupervised_trainer = UnsupervisedTrainer(
                self._base_model,
                self.adata,
                train_size=train_size,
                test_size=test_size,
                n_iter_kl_warmup=n_iter_kl_warmup,
                n_epochs_kl_warmup=n_epochs_kl_warmup,
                frequency=frequency,
                use_cuda=self.use_cuda,
                **unsupervised_trainer_kwargs,
            )
            self._unsupervised_trainer.train(n_epochs=n_epochs_unsupervised,
                                             lr=lr,
                                             **unsupervised_train_kwargs)
            self.unsupervised_history_ = self._unsupervised_trainer.history
            self._is_trained_base = True

        # strict=False: only the parameters present in both modules are copied.
        self.model.load_state_dict(self._base_model.state_dict(), strict=False)

        # An explicit "frequency" in semisupervised_trainer_kwargs wins over
        # the shared `frequency` argument.
        if "frequency" not in semisupervised_trainer_kwargs and frequency is not None:
            semisupervised_trainer_kwargs["frequency"] = frequency

        # Phase 2: semisupervised training on the labelled/unlabelled split
        # computed in __init__.
        self.trainer = SemiSupervisedTrainer(
            self.model,
            self.adata,
            use_cuda=self.use_cuda,
            **semisupervised_trainer_kwargs,
        )
        self.trainer.unlabelled_set = self.trainer.create_scvi_dl(
            indices=self._unlabeled_indices)
        self.trainer.labelled_set = self.trainer.create_scvi_dl(
            indices=self._labeled_indices)
        # Stored before training starts.
        # NOTE(review): presumably the trainer updates this object in place -- confirm.
        self.semisupervised_history_ = self.trainer.history
        self.trainer.train(
            n_epochs=n_epochs_semisupervised,
            **semisupervised_train_kwargs,
        )

        self.is_trained_ = True

    def predict(
        self,
        adata: Optional[AnnData] = None,
        indices: Optional[Sequence[int]] = None,
        soft: bool = False,
        batch_size: int = 128,
    ) -> Union[np.ndarray, pd.DataFrame]:
        """
        Return cell label predictions.

        Parameters
        ----------
        adata
            AnnData object that has been registered via :func:`~scvi.data.setup_anndata`.
        indices
            Indices of cells in adata to use. If `None`, all cells are used.
        soft
            Return probabilities for each class label.
        batch_size
            Minibatch size to use.

        Returns
        -------
        If `soft` is `False`, an array with one decoded label per cell. Otherwise a
        DataFrame with one column per entry of the label mapping, indexed by the
        observation names of the selected cells.
        """
        adata = self._validate_anndata(adata)

        # Materialise the indices: they are reused below to index obs_names.
        if indices is None:
            indices = np.arange(adata.n_obs)

        scdl = self._make_scvi_dl(adata=adata,
                                  indices=indices,
                                  batch_size=batch_size)

        # The first returned value is not needed here.
        _, pred = scdl.sequential().compute_predictions(soft=soft)

        if not soft:
            # Decode each predicted code into its label.
            predictions = []
            for p in pred:
                predictions.append(self._code_to_label[p])

            return np.array(predictions)
        else:
            pred = pd.DataFrame(
                pred,
                columns=self._label_mapping,
                index=adata.obs_names[indices],
            )
            return pred
Exemplo n.º 3
0
    def train(
        self,
        n_epochs_unsupervised: Optional[int] = None,
        n_epochs_semisupervised: Optional[int] = None,
        train_size: float = 0.9,
        test_size: float = None,
        lr: float = 1e-3,
        n_epochs_kl_warmup: int = 400,
        n_iter_kl_warmup: Optional[int] = None,
        frequency: Optional[int] = None,
        unsupervised_trainer_kwargs: dict = {},
        semisupervised_trainer_kwargs: dict = {},
        unsupervised_train_kwargs: dict = {},
        semisupervised_train_kwargs: dict = {},
    ):
        """
        Fit the model in two phases: unsupervised, then semisupervised.

        The base model is pre-trained with an ``UnsupervisedTrainer`` unless it
        is already flagged as trained. Its state dict is then loaded
        (non-strictly) into ``self.model``, which is trained with a
        ``SemiSupervisedTrainer`` on the stored labelled/unlabelled indices.

        Parameters
        ----------
        n_epochs_unsupervised
            Epochs of unsupervised pre-training. If `None`,
            ``min(round(20000 / n_cells * 400), 400)``.
        n_epochs_semisupervised
            Epochs of semisupervised training. If `None`, one third of
            `n_epochs_unsupervised` (rounded), clipped to the range [2, 10].
        train_size
            Size of training set in the range [0.0, 1.0] (unsupervised trainer only).
        test_size
            Size of the test set (unsupervised trainer only). If `None`, defaults to
            1 - `train_size`. If `train_size + test_size < 1`, the remaining cells
            belong to a validation set.
        lr
            Learning rate for the unsupervised optimization.
        n_epochs_kl_warmup
            Number of passes through dataset for scaling term on KL divergence to go from 0 to 1.
        n_iter_kl_warmup
            Number of minibatches for scaling term on KL divergence to go from 0 to 1.
            To use, set to not `None` and set `n_epochs_kl_warmup` to `None`.
        frequency
            Frequency with which metrics are computed for both trainers. A
            ``"frequency"`` entry in `semisupervised_trainer_kwargs` takes
            precedence for the semisupervised trainer.
        unsupervised_trainer_kwargs
            Other keyword args for :class:`~scvi.core.trainers.UnsupervisedTrainer`.
        semisupervised_trainer_kwargs
            Other keyword args for :class:`~scvi.core.trainers.SemiSupervisedTrainer`.
        unsupervised_train_kwargs
            Keyword args for the train method of :class:`~scvi.core.trainers.UnsupervisedTrainer`.
        semisupervised_train_kwargs
            Keyword args for the train method of :class:`~scvi.core.trainers.SemiSupervisedTrainer`.
        """
        # Private copies: the default dicts are shared between calls and the
        # caller's dictionaries must not be modified.
        pre_trainer_kwargs = dict(unsupervised_trainer_kwargs)
        semi_trainer_kwargs = dict(semisupervised_trainer_kwargs)
        pre_fit_kwargs = dict(unsupervised_train_kwargs)
        semi_fit_kwargs = dict(semisupervised_train_kwargs)

        # Resolve the epoch budgets.
        if n_epochs_unsupervised is None:
            n_cells = self.adata.shape[0]
            n_epochs_unsupervised = np.min(
                [round((20000 / n_cells) * 400), 400])
        if n_epochs_semisupervised is None:
            at_least_two = np.max([2, round(n_epochs_unsupervised / 3.0)])
            n_epochs_semisupervised = int(np.min([10, at_least_two]))
        logger.info(
            f"Training Unsupervised Trainer for {n_epochs_unsupervised} epochs."
        )
        logger.info(
            f"Training SemiSupervised Trainer for {n_epochs_semisupervised} epochs."
        )

        # Phase 1: pre-train the base model, at most once per instance.
        base_ready = self._is_trained_base is True
        if not base_ready:
            pre_trainer = UnsupervisedTrainer(
                self._base_model,
                self.adata,
                train_size=train_size,
                test_size=test_size,
                n_iter_kl_warmup=n_iter_kl_warmup,
                n_epochs_kl_warmup=n_epochs_kl_warmup,
                frequency=frequency,
                use_cuda=self.use_cuda,
                **pre_trainer_kwargs,
            )
            self._unsupervised_trainer = pre_trainer
            pre_trainer.train(
                n_epochs=n_epochs_unsupervised,
                lr=lr,
                **pre_fit_kwargs,
            )
            self.unsupervised_history_ = pre_trainer.history
            self._is_trained_base = True

        # Hand the base weights to the semisupervised module.
        base_weights = self._base_model.state_dict()
        self.model.load_state_dict(base_weights, strict=False)

        # The shared frequency is only a fallback for the second trainer.
        if frequency is not None:
            semi_trainer_kwargs.setdefault("frequency", frequency)

        # Phase 2: semisupervised training.
        semi_trainer = SemiSupervisedTrainer(
            self.model,
            self.adata,
            use_cuda=self.use_cuda,
            **semi_trainer_kwargs,
        )
        self.trainer = semi_trainer
        semi_trainer.unlabelled_set = semi_trainer.create_scvi_dl(
            indices=self._unlabeled_indices)
        semi_trainer.labelled_set = semi_trainer.create_scvi_dl(
            indices=self._labeled_indices)
        self.semisupervised_history_ = semi_trainer.history
        semi_trainer.train(n_epochs=n_epochs_semisupervised, **semi_fit_kwargs)

        self.is_trained_ = True
Exemplo n.º 4
0
 def loss(self, tensors_all, tensors_labelled=None):
     """Dispatch to the semisupervised or the unsupervised loss.

     When ``self._labelled`` is truthy the ``SemiSupervisedTrainer`` loss is
     evaluated on both tensor collections; otherwise only ``tensors_all`` is
     passed to the ``UnsupervisedTrainer`` loss and ``tensors_labelled`` is
     ignored.
     """
     if not self._labelled:
         return UnsupervisedTrainer.loss(self, tensors_all)
     return SemiSupervisedTrainer.loss(self, tensors_all, tensors_labelled)
Exemplo n.º 5
0
    def train(
        self,
        n_epochs: Optional[int] = None,
        train_size: float = 0.9,
        test_size: Optional[float] = None,
        lr: float = 1e-3,
        n_epochs_kl_warmup: int = 400,
        n_iter_kl_warmup: Optional[int] = None,
        frequency: Optional[int] = None,
        train_fun_kwargs: dict = {},
        **kwargs,
    ):
        """
        Trains the model using amortized variational inference.

        A trainer is only created while ``self.is_trained_`` is ``False``. On later
        calls the existing ``self.trainer`` is reused, so the trainer-construction
        arguments (`train_size`, `test_size`, the KL warmup settings, `frequency`
        and ``**kwargs``) have no effect.

        Parameters
        ----------
        n_epochs
            Number of passes through the dataset. If `None`, defaults to
            ``min(round(20000 / n_cells * 400), 400)``. Ignored when
            `train_fun_kwargs` already contains ``"n_epochs"``.
        train_size
            Size of training set in the range [0.0, 1.0].
        test_size
            Size of the test set. If `None`, defaults to 1 - `train_size`. If
            `train_size + test_size < 1`, the remaining cells belong to a validation set.
        lr
            Learning rate for optimization. Ignored when `train_fun_kwargs`
            already contains ``"lr"``.
        n_epochs_kl_warmup
            Number of passes through dataset for scaling term on KL divergence to go from 0 to 1.
        n_iter_kl_warmup
            Number of minibatches for scaling term on KL divergence to go from 0 to 1.
            To use, set to not `None` and set `n_epochs_kl_warmup` to `None`.
        frequency
            Frequency with which metrics are computed on the data for train/test/val sets.
        train_fun_kwargs
            Keyword args for the train method of :class:`~scvi.core.trainers.UnsupervisedTrainer`.
            Entries given here take precedence over `n_epochs` and `lr`.
        **kwargs
            Other keyword args for :class:`~scvi.core.trainers.UnsupervisedTrainer`.
        """
        # Copy: the default dict is shared between calls and is filled in below.
        train_fun_kwargs = dict(train_fun_kwargs)
        if self.is_trained_ is False:
            self.trainer = UnsupervisedTrainer(
                self.model,
                self.adata,
                train_size=train_size,
                test_size=test_size,
                n_iter_kl_warmup=n_iter_kl_warmup,
                n_epochs_kl_warmup=n_epochs_kl_warmup,
                frequency=frequency,
                use_cuda=self.use_cuda,
                **kwargs,
            )
            # Record the data split and the history object of the new trainer.
            self.train_indices_ = self.trainer.train_set.indices
            self.test_indices_ = self.trainer.test_set.indices
            self.validation_indices_ = self.trainer.validation_set.indices
            self.history_ = self.trainer.history
        # for autotune: values already present in train_fun_kwargs win
        if "n_epochs" not in train_fun_kwargs:
            if n_epochs is None:
                n_cells = self.adata.n_obs
                n_epochs = np.min([round((20000 / n_cells) * 400), 400])
            train_fun_kwargs["n_epochs"] = n_epochs
        if "lr" not in train_fun_kwargs:
            train_fun_kwargs["lr"] = lr

        # Log the epoch count that is actually handed to the trainer; the
        # `n_epochs` argument may be None or stale when train_fun_kwargs
        # supplied its own value.
        logger.info("Training for {} epochs".format(
            train_fun_kwargs["n_epochs"]))
        self.trainer.train(**train_fun_kwargs)
        self.is_trained_ = True
Exemplo n.º 6
0
class VAEMixin:
    """
    Training and evaluation helpers shared by VAE-based models.

    The methods rely on attributes and helpers supplied by the host class:
    ``model``, ``adata``, ``use_cuda``, ``is_trained_``, ``_validate_anndata``
    and ``_make_scvi_dl``.
    """

    def train(
        self,
        n_epochs: Optional[int] = None,
        train_size: float = 0.9,
        test_size: Optional[float] = None,
        lr: float = 1e-3,
        n_epochs_kl_warmup: int = 400,
        n_iter_kl_warmup: Optional[int] = None,
        frequency: Optional[int] = None,
        train_fun_kwargs: dict = {},
        **kwargs,
    ):
        """
        Trains the model using amortized variational inference.

        A trainer is only created while ``self.is_trained_`` is ``False``. On later
        calls the existing ``self.trainer`` is reused, so the trainer-construction
        arguments (`train_size`, `test_size`, the KL warmup settings, `frequency`
        and ``**kwargs``) have no effect.

        Parameters
        ----------
        n_epochs
            Number of passes through the dataset. If `None`, defaults to
            ``min(round(20000 / n_cells * 400), 400)``. Ignored when
            `train_fun_kwargs` already contains ``"n_epochs"``.
        train_size
            Size of training set in the range [0.0, 1.0].
        test_size
            Size of the test set. If `None`, defaults to 1 - `train_size`. If
            `train_size + test_size < 1`, the remaining cells belong to a validation set.
        lr
            Learning rate for optimization. Ignored when `train_fun_kwargs`
            already contains ``"lr"``.
        n_epochs_kl_warmup
            Number of passes through dataset for scaling term on KL divergence to go from 0 to 1.
        n_iter_kl_warmup
            Number of minibatches for scaling term on KL divergence to go from 0 to 1.
            To use, set to not `None` and set `n_epochs_kl_warmup` to `None`.
        frequency
            Frequency with which metrics are computed on the data for train/test/val sets.
        train_fun_kwargs
            Keyword args for the train method of :class:`~scvi.core.trainers.UnsupervisedTrainer`.
            Entries given here take precedence over `n_epochs` and `lr`.
        **kwargs
            Other keyword args for :class:`~scvi.core.trainers.UnsupervisedTrainer`.
        """
        # Copy: the default dict is shared between calls and is filled in below.
        train_fun_kwargs = dict(train_fun_kwargs)
        if self.is_trained_ is False:
            self.trainer = UnsupervisedTrainer(
                self.model,
                self.adata,
                train_size=train_size,
                test_size=test_size,
                n_iter_kl_warmup=n_iter_kl_warmup,
                n_epochs_kl_warmup=n_epochs_kl_warmup,
                frequency=frequency,
                use_cuda=self.use_cuda,
                **kwargs,
            )
            # Record the data split and the history object of the new trainer.
            self.train_indices_ = self.trainer.train_set.indices
            self.test_indices_ = self.trainer.test_set.indices
            self.validation_indices_ = self.trainer.validation_set.indices
            self.history_ = self.trainer.history
        # for autotune: values already present in train_fun_kwargs win
        if "n_epochs" not in train_fun_kwargs:
            if n_epochs is None:
                n_cells = self.adata.n_obs
                n_epochs = np.min([round((20000 / n_cells) * 400), 400])
            train_fun_kwargs["n_epochs"] = n_epochs
        if "lr" not in train_fun_kwargs:
            train_fun_kwargs["lr"] = lr

        # Log the epoch count that is actually handed to the trainer; the
        # `n_epochs` argument may be None or stale when train_fun_kwargs
        # supplied its own value.
        logger.info("Training for {} epochs".format(
            train_fun_kwargs["n_epochs"]))
        self.trainer.train(**train_fun_kwargs)
        self.is_trained_ = True

    @torch.no_grad()
    def get_elbo(
        self,
        adata: Optional[AnnData] = None,
        indices: Optional[Sequence[int]] = None,
        batch_size: Optional[int] = None,
    ) -> float:
        """
        Return the ELBO for the data.

        The ELBO is a lower bound on the log likelihood of the data used for optimization
        of VAEs. Note, this is not the negative ELBO, higher is better.

        Parameters
        ----------
        adata
            AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
            AnnData object used to initialize the model.
        indices
            Indices of cells in adata to use. If `None`, all cells are used.
        batch_size
            Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.

        Returns
        -------
        The value computed by the data loader's ``elbo()`` with its sign flipped.
        """
        adata = self._validate_anndata(adata)
        scdl = self._make_scvi_dl(adata=adata,
                                  indices=indices,
                                  batch_size=batch_size)

        # Sign flip so that higher is better.
        return -scdl.elbo()

    @torch.no_grad()
    def get_marginal_ll(
        self,
        adata: Optional[AnnData] = None,
        indices: Optional[Sequence[int]] = None,
        n_mc_samples: int = 1000,
        batch_size: Optional[int] = None,
    ) -> float:
        """
        Return the marginal LL for the data.

        The computation here is a biased estimator of the marginal log likelihood of the data.
        Note, this is not the negative log likelihood, higher is better.

        Parameters
        ----------
        adata
            AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
            AnnData object used to initialize the model.
        indices
            Indices of cells in adata to use. If `None`, all cells are used.
        n_mc_samples
            Number of Monte Carlo samples to use for marginal LL estimation.
        batch_size
            Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.

        Returns
        -------
        The value computed by the data loader's ``marginal_ll()`` with its sign flipped.
        """
        adata = self._validate_anndata(adata)
        scdl = self._make_scvi_dl(adata=adata,
                                  indices=indices,
                                  batch_size=batch_size)

        # Sign flip so that higher is better.
        return -scdl.marginal_ll(n_mc_samples=n_mc_samples)

    @torch.no_grad()
    def get_reconstruction_error(
        self,
        adata: Optional[AnnData] = None,
        indices: Optional[Sequence[int]] = None,
        batch_size: Optional[int] = None,
    ) -> float:
        r"""
        Return the reconstruction error for the data.

        This is typically written as :math:`p(x \mid z)`, the likelihood term given one posterior sample.
        Note, this is not the negative likelihood, higher is better.

        Parameters
        ----------
        adata
            AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
            AnnData object used to initialize the model.
        indices
            Indices of cells in adata to use. If `None`, all cells are used.
        batch_size
            Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.

        Returns
        -------
        The value computed by the data loader's ``reconstruction_error()`` with its sign flipped.
        """
        adata = self._validate_anndata(adata)
        scdl = self._make_scvi_dl(adata=adata,
                                  indices=indices,
                                  batch_size=batch_size)

        # Sign flip so that higher is better.
        return -scdl.reconstruction_error()

    @torch.no_grad()
    def get_latent_representation(
        self,
        adata: Optional[AnnData] = None,
        indices: Optional[Sequence[int]] = None,
        give_mean: bool = True,
        mc_samples: int = 5000,
        batch_size: Optional[int] = None,
    ) -> np.ndarray:
        r"""
        Return the latent representation for each cell.

        This is denoted as :math:`z_n` in our manuscripts.

        Parameters
        ----------
        adata
            AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
            AnnData object used to initialize the model.
        indices
            Indices of cells in adata to use. If `None`, all cells are used.
        give_mean
            Give mean of distribution or sample from it.
        mc_samples
            For distributions with no closed-form mean (e.g., `logistic normal`), how many Monte Carlo
            samples to take for computing mean.
        batch_size
            Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.

        Returns
        -------
        latent_representation : np.ndarray
            Low-dimensional representation for each cell

        Raises
        ------
        RuntimeError
            If the model has not been trained (``is_trained_`` is ``False``).
        """
        if self.is_trained_ is False:
            raise RuntimeError("Please train the model first.")

        adata = self._validate_anndata(adata)
        scdl = self._make_scvi_dl(adata=adata,
                                  indices=indices,
                                  batch_size=batch_size)
        latent = []
        for tensors in scdl:
            x = tensors[_CONSTANTS.X_KEY]
            b = tensors[_CONSTANTS.BATCH_KEY]
            z = self.model.sample_from_posterior_z(x,
                                                   b,
                                                   give_mean=give_mean,
                                                   n_samples=mc_samples)
            # Move each minibatch to the CPU before it is accumulated.
            latent.append(z.cpu())
        # Concatenate the minibatches and convert to NumPy.
        return np.array(torch.cat(latent))
Exemplo n.º 7
0
def test_annealing_procedures(save_path):
    """Check the KL weight reached after two epochs for three warmup settings.

    The same ``VAE`` instance is trained three times in a row, each time
    through a fresh ``UnsupervisedTrainer``:

    * one warmup epoch     -> weight must be at least 0.99,
    * five warmup epochs   -> weight must be at most 0.99,
    * one warmup iteration -> weight must be at least 0.99.
    """
    adata = scvi.data.cortex(save_path=save_path)
    scvi.data.setup_anndata(adata, labels_key="cell_type")

    summary = adata.uns["_scvi"]["summary_stats"]
    vae = VAE(summary["n_genes"], summary["n_batch"])

    def kl_weight_after_two_epochs(**warmup):
        # Build a new trainer around the shared model, train, report the weight.
        trainer = UnsupervisedTrainer(
            vae,
            adata,
            train_size=0.5,
            use_cuda=use_cuda,
            **warmup,
        )
        trainer.train(n_epochs=2)
        return trainer.kl_weight

    # epoch-based warmup
    weight = kl_weight_after_two_epochs(n_epochs_kl_warmup=1)
    assert weight >= 0.99, "Annealing should be over"

    weight = kl_weight_after_two_epochs(n_epochs_kl_warmup=5)
    assert weight <= 0.99, "Annealing should be proceeding"

    # iteration-based warmup
    weight = kl_weight_after_two_epochs(
        n_iter_kl_warmup=1, n_epochs_kl_warmup=None
    )
    assert weight >= 0.99, "Annealing should be over"