class SCANVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseModelClass):
    """
    Single-cell annotation using variational inference [Xu21]_.

    Inspired from M1 + M2 model, as described in (https://arxiv.org/pdf/1406.5298.pdf).

    Parameters
    ----------
    adata
        AnnData object that has been registered via :func:`~scvi.data.setup_anndata`.
    unlabeled_category
        Value used for unlabeled cells in `labels_key` used to setup AnnData with scvi.
    n_hidden
        Number of nodes per hidden layer.
    n_latent
        Dimensionality of the latent space.
    n_layers
        Number of hidden layers used for encoder and decoder NNs.
    dropout_rate
        Dropout rate for neural networks.
    dispersion
        One of the following:

        * ``'gene'`` - dispersion parameter of NB is constant per gene across cells
        * ``'gene-batch'`` - dispersion can differ between different batches
        * ``'gene-label'`` - dispersion can differ between different labels
        * ``'gene-cell'`` - dispersion can differ for every gene in every cell
    gene_likelihood
        One of:

        * ``'nb'`` - Negative binomial distribution
        * ``'zinb'`` - Zero-inflated negative binomial distribution
        * ``'poisson'`` - Poisson distribution
    **model_kwargs
        Keyword args for :class:`~scvi.module.SCANVAE`

    Examples
    --------
    >>> adata = anndata.read_h5ad(path_to_anndata)
    >>> scvi.data.setup_anndata(adata, batch_key="batch", labels_key="labels")
    >>> vae = scvi.model.SCANVI(adata, "Unknown")
    >>> vae.train()
    >>> adata.obsm["X_scVI"] = vae.get_latent_representation()
    >>> adata.obs["pred_label"] = vae.predict()

    Notes
    -----
    See further usage examples in the following tutorials:

    1. :doc:`/user_guide/notebooks/harmonization`
    2. :doc:`/user_guide/notebooks/scarches_scvi_tools`
    3. :doc:`/user_guide/notebooks/seed_labeling`
    """

    def __init__(
        self,
        adata: AnnData,
        unlabeled_category: Union[str, int, float],
        n_hidden: int = 128,
        n_latent: int = 10,
        n_layers: int = 1,
        dropout_rate: float = 0.1,
        dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene",
        gene_likelihood: Literal["zinb", "nb", "poisson"] = "zinb",
        **model_kwargs,
    ):
        super(SCANVI, self).__init__(adata)
        scanvae_model_kwargs = dict(model_kwargs)

        self.unlabeled_category_ = unlabeled_category
        has_unlabeled = self._set_indices_and_labels()

        # Use the semi-supervised loader only when at least one cell is labeled.
        if len(self._labeled_indices) != 0:
            self._dl_cls = SemiSupervisedDataLoader
        else:
            self._dl_cls = AnnDataLoader

        # ignores unlabeled category (it is moved to the last slot of the
        # label mapping by _set_indices_and_labels when present)
        n_labels = (
            self.summary_stats["n_labels"] - 1
            if has_unlabeled
            else self.summary_stats["n_labels"]
        )
        n_cats_per_cov = (
            self.scvi_setup_dict_["extra_categoricals"]["n_cats_per_key"]
            if "extra_categoricals" in self.scvi_setup_dict_
            else None
        )

        self.module = SCANVAE(
            n_input=self.summary_stats["n_vars"],
            n_batch=self.summary_stats["n_batch"],
            n_labels=n_labels,
            n_continuous_cov=self.summary_stats["n_continuous_covs"],
            n_cats_per_cov=n_cats_per_cov,
            n_hidden=n_hidden,
            n_latent=n_latent,
            n_layers=n_layers,
            dropout_rate=dropout_rate,
            dispersion=dispersion,
            gene_likelihood=gene_likelihood,
            **scanvae_model_kwargs,
        )

        self.unsupervised_history_ = None
        self.semisupervised_history_ = None

        self._model_summary_string = (
            "ScanVI Model with the following params: \nunlabeled_category: {}, n_hidden: {}, n_latent: {}"
            ", n_layers: {}, dropout_rate: {}, dispersion: {}, gene_likelihood: {}"
        ).format(
            unlabeled_category,
            n_hidden,
            n_latent,
            n_layers,
            dropout_rate,
            dispersion,
            gene_likelihood,
        )
        # Must run before any parameter is rebound: it inspects the current locals.
        self.init_params_ = self._get_init_params(locals())
        self.was_pretrained = False

    @classmethod
    def from_scvi_model(
        cls,
        scvi_model: SCVI,
        unlabeled_category: Union[str, int, float],
        adata: Optional[AnnData] = None,
        **scanvi_kwargs,
    ):
        """
        Initialize scanVI model with weights from pretrained scVI model.

        Parameters
        ----------
        scvi_model
            Pretrained scvi model
        unlabeled_category
            Value used for unlabeled cells in `labels_key` used to setup AnnData with scvi.
        adata
            AnnData object that has been registered via :func:`~scvi.data.setup_anndata`.
            Defaults to the AnnData attached to ``scvi_model``.
        scanvi_kwargs
            kwargs for scanVI model. Any kwarg that was already used to build
            ``scvi_model`` is dropped (with a warning) in favour of the scVI value.

        Returns
        -------
        A new SCANVI instance whose module weights are loaded (non-strictly)
        from ``scvi_model.module``; ``was_pretrained`` is set to True.
        """
        if scvi_model.is_trained_ is False:
            warnings.warn("Passed in scvi model hasn't been trained yet.")

        scanvi_kwargs = dict(scanvi_kwargs)
        init_params = scvi_model.init_params_
        non_kwargs = init_params["non_kwargs"]
        # Flatten the var-keyword groups (e.g. ``model_kwargs``) into one dict.
        kwargs = {
            k: v
            for group in init_params["kwargs"].values()
            for (k, v) in group.items()
        }

        # The scVI architecture must be reused as-is, so user overrides lose.
        for k, v in {**non_kwargs, **kwargs}.items():
            if k in scanvi_kwargs:
                warnings.warn(
                    "Ignoring param '{}' as it was already passed in to ".format(k)
                    + "pretrained scvi model with value {}.".format(v)
                )
                del scanvi_kwargs[k]

        if adata is None:
            adata = scvi_model.adata

        scanvi_model = cls(
            adata, unlabeled_category, **non_kwargs, **kwargs, **scanvi_kwargs
        )
        scvi_state_dict = scvi_model.module.state_dict()
        # strict=False: SCANVAE has extra (classifier) parameters absent from scVI.
        scanvi_model.module.load_state_dict(scvi_state_dict, strict=False)
        scanvi_model.was_pretrained = True

        return scanvi_model

    def _set_indices_and_labels(self):
        """
        Set indices and make unlabeled cat as the last cat.

        Also stores ``_label_mapping``, ``_code_to_label``, ``original_label_key``
        and the labeled/unlabeled index arrays on ``self``.

        Returns
        -------
        True if the unlabeled category occurs in the labels (in which case the
        labels column was re-encoded with it in the last position), else False.
        """
        # get indices for labeled and unlabeled cells
        key = self.scvi_setup_dict_["data_registry"][_CONSTANTS.LABELS_KEY]["attr_key"]
        categorical_mapping = self.scvi_setup_dict_["categorical_mappings"][key]
        mapping = categorical_mapping["mapping"]
        original_key = categorical_mapping["original_key"]
        labels = np.asarray(self.adata.obs[original_key]).ravel()

        if self.unlabeled_category_ in labels:
            unlabeled_idx = np.where(mapping == self.unlabeled_category_)[0][0]
            # move unlabeled category to be the last position
            mapping[unlabeled_idx], mapping[-1] = mapping[-1], mapping[unlabeled_idx]
            cat_dtype = CategoricalDtype(categories=mapping, ordered=True)
            # rerun setup for the labels column with the new category order
            _make_obs_column_categorical(
                self.adata,
                original_key,
                "_scvi_labels",
                categorical_dtype=cat_dtype,
            )
            remapped = True
        else:
            remapped = False

        # Refresh the cached setup dict; the re-encoding above may have changed it.
        self.scvi_setup_dict_ = self.adata.uns["_scvi"]
        self._label_mapping = mapping

        # set unlabeled and labeled indices
        self._unlabeled_indices = np.argwhere(
            labels == self.unlabeled_category_
        ).ravel()
        self._labeled_indices = np.argwhere(labels != self.unlabeled_category_).ravel()
        self._code_to_label = dict(enumerate(self._label_mapping))
        self.original_label_key = original_key

        return remapped

    def predict(
        self,
        adata: Optional[AnnData] = None,
        indices: Optional[Sequence[int]] = None,
        soft: bool = False,
        batch_size: Optional[int] = None,
    ) -> Union[np.ndarray, pd.DataFrame]:
        """
        Return cell label predictions.

        Parameters
        ----------
        adata
            AnnData object that has been registered via :func:`~scvi.data.setup_anndata`.
        indices
            Indices of cells in adata to use. If `None`, all cells are used.
        soft
            If True, returns per class probabilities
        batch_size
            Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.

        Returns
        -------
        If ``soft`` is False, an array with one predicted label per selected cell.
        If ``soft`` is True, a DataFrame of per-class outputs indexed by the
        selected cells' ``obs_names``, with one column per label.
        """
        adata = self._validate_anndata(adata)

        if indices is None:
            indices = np.arange(adata.n_obs)

        scdl = self._make_data_loader(
            adata=adata,
            indices=indices,
            batch_size=batch_size,
        )

        y_pred = []
        for tensors in scdl:
            x = tensors[_CONSTANTS.X_KEY]
            batch = tensors[_CONSTANTS.BATCH_KEY]
            pred = self.module.classify(x, batch)
            if not soft:
                pred = pred.argmax(dim=1)
            y_pred.append(pred.detach().cpu())

        y_pred = torch.cat(y_pred).numpy()
        if not soft:
            return np.array([self._code_to_label[p] for p in y_pred])

        # The unlabeled category (if any) sits last in _label_mapping, so slicing
        # to the classifier's output width drops it — assuming classify emits one
        # column per model label; TODO confirm against SCANVAE.
        n_labels = y_pred.shape[1]
        return pd.DataFrame(
            y_pred,
            columns=self._label_mapping[:n_labels],
            index=adata.obs_names[indices],
        )

    def train(
        self,
        max_epochs: Optional[int] = None,
        n_samples_per_label: Optional[float] = None,
        check_val_every_n_epoch: Optional[int] = None,
        train_size: float = 0.9,
        validation_size: Optional[float] = None,
        batch_size: int = 128,
        use_gpu: Optional[Union[str, int, bool]] = None,
        plan_kwargs: Optional[dict] = None,
        **trainer_kwargs,
    ):
        """
        Train the model.

        Parameters
        ----------
        max_epochs
            Number of passes through the dataset for semisupervised training.
            If `None`, defaults to ``min(round(20000 / n_cells * 400), 400)``,
            further reduced to between 2 and 10 epochs for a pretrained model.
        n_samples_per_label
            Number of subsamples for each label class to sample per epoch. By default, there
            is no label subsampling.
        check_val_every_n_epoch
            Frequency with which metrics are computed on the data for the validation set.
        train_size
            Size of training set in the range [0.0, 1.0].
        validation_size
            Size of the validation set. If `None`, defaults to 1 - `train_size`. If
            `train_size + validation_size < 1`, the remaining cells belong to a test set.
        batch_size
            Minibatch size to use during training.
        use_gpu
            Use default GPU if available (if None or True), or index of GPU to use (if int),
            or name of GPU (if str, e.g., `'cuda:0'`), or use CPU (if False).
        plan_kwargs
            Keyword args for :class:`~scvi.train.SemiSupervisedTrainingPlan`.
        **trainer_kwargs
            Other keyword args for :class:`~scvi.train.Trainer`. User-supplied
            ``callbacks`` (a list, a single callback, or None) are kept and the
            label-subsampling callback is appended when labeled cells exist.
        """
        if max_epochs is None:
            n_cells = self.adata.n_obs
            max_epochs = min(round((20000 / n_cells) * 400), 400)

            if self.was_pretrained:
                # Fine-tuning from scVI weights needs far fewer epochs.
                max_epochs = int(min(10, max(2, round(max_epochs / 3.0))))

        logger.info("Training for %s epochs.", max_epochs)

        plan_kwargs = {} if plan_kwargs is None else plan_kwargs

        # if we have labeled cells, we want to subsample labels each epoch
        sampler_callback = (
            [SubSampleLabels()] if len(self._labeled_indices) != 0 else []
        )

        data_splitter = SemiSupervisedDataSplitter(
            adata=self.adata,
            unlabeled_category=self.unlabeled_category_,
            train_size=train_size,
            validation_size=validation_size,
            n_samples_per_label=n_samples_per_label,
            batch_size=batch_size,
            use_gpu=use_gpu,
        )
        training_plan = SemiSupervisedTrainingPlan(self.module, **plan_kwargs)

        # Lists have no ``concatenate``; build a fresh list so user callbacks
        # (list, single callback, or None) are preserved alongside ours.
        user_callbacks = trainer_kwargs.get("callbacks") or []
        if not isinstance(user_callbacks, (list, tuple)):
            user_callbacks = [user_callbacks]
        trainer_kwargs["callbacks"] = list(user_callbacks) + sampler_callback

        runner = TrainRunner(
            self,
            training_plan=training_plan,
            data_splitter=data_splitter,
            max_epochs=max_epochs,
            use_gpu=use_gpu,
            check_val_every_n_epoch=check_val_every_n_epoch,
            **trainer_kwargs,
        )
        return runner()
def __init__(
    self,
    adata: AnnData,
    unlabeled_category: Union[str, int, float],
    n_hidden: int = 128,
    n_latent: int = 10,
    n_layers: int = 1,
    dropout_rate: float = 0.1,
    dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene",
    gene_likelihood: Literal["zinb", "nb", "poisson"] = "zinb",
    **model_kwargs,
):
    """
    Construct a SCANVI model and its underlying SCANVAE module.

    Records ``unlabeled_category`` on the instance, and calls
    ``_set_indices_and_labels`` to locate labeled and unlabeled cells. It then
    picks the data-loader class, builds the SCANVAE module, and stores a
    human-readable summary string and the init parameters.

    Parameters
    ----------
    adata
        Registered AnnData object.
    unlabeled_category
        Label value that marks a cell as unlabeled.
    n_hidden, n_latent, n_layers, dropout_rate, dispersion, gene_likelihood
        Architecture / likelihood options forwarded to SCANVAE.
    **model_kwargs
        Extra keyword arguments forwarded (copied) to SCANVAE.
    """
    super(SCANVI, self).__init__(adata)
    scanvae_model_kwargs = dict(model_kwargs)

    self.unlabeled_category_ = unlabeled_category
    has_unlabeled = self._set_indices_and_labels()

    # No labeled cells at all -> plain loader; otherwise semi-supervised.
    if len(self._labeled_indices) == 0:
        self._dl_cls = AnnDataLoader
    else:
        self._dl_cls = SemiSupervisedDataLoader

    # The unlabeled category does not count as a class for the model.
    if has_unlabeled:
        n_labels = self.summary_stats["n_labels"] - 1
    else:
        n_labels = self.summary_stats["n_labels"]

    if "extra_categoricals" in self.scvi_setup_dict_:
        n_cats_per_cov = self.scvi_setup_dict_["extra_categoricals"]["n_cats_per_key"]
    else:
        n_cats_per_cov = None

    self.module = SCANVAE(
        n_input=self.summary_stats["n_vars"],
        n_batch=self.summary_stats["n_batch"],
        n_labels=n_labels,
        n_continuous_cov=self.summary_stats["n_continuous_covs"],
        n_cats_per_cov=n_cats_per_cov,
        n_hidden=n_hidden,
        n_latent=n_latent,
        n_layers=n_layers,
        dropout_rate=dropout_rate,
        dispersion=dispersion,
        gene_likelihood=gene_likelihood,
        **scanvae_model_kwargs,
    )

    self.unsupervised_history_ = None
    self.semisupervised_history_ = None

    self._model_summary_string = (
        "ScanVI Model with the following params: \n"
        f"unlabeled_category: {unlabeled_category}, n_hidden: {n_hidden}, "
        f"n_latent: {n_latent}, n_layers: {n_layers}, "
        f"dropout_rate: {dropout_rate}, dispersion: {dispersion}, "
        f"gene_likelihood: {gene_likelihood}"
    )
    # locals() is handed to the helper, so no parameter may be rebound above.
    self.init_params_ = self._get_init_params(locals())
    self.was_pretrained = False
def __init__(
    self,
    adata: AnnData,
    n_hidden: int = 128,
    n_latent: int = 10,
    n_layers: int = 1,
    dropout_rate: float = 0.1,
    dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene",
    gene_likelihood: Literal["zinb", "nb", "poisson"] = "zinb",
    **model_kwargs,
):
    """
    Construct a SCANVI model (AnnDataManager-based registry variant).

    Locates labeled and unlabeled cells via ``_set_indices_and_labels``, then
    picks the data-loader class. It derives label, batch and covariate counts
    from the AnnData manager and initialises library-size priors unless a
    size-factor key is registered. Finally it builds the SCANVAE module.

    Parameters
    ----------
    adata
        Registered AnnData object.
    n_hidden, n_latent, n_layers, dropout_rate, dispersion, gene_likelihood
        Architecture / likelihood options forwarded to SCANVAE.
    **model_kwargs
        Extra keyword arguments forwarded (copied) to SCANVAE.

    Notes
    -----
    ``self.has_unlabeled`` and ``self.unlabeled_category_`` are read here but
    not assigned here; presumably ``_set_indices_and_labels`` sets them — confirm.
    """
    super(SCANVI, self).__init__(adata)
    scanvae_model_kwargs = dict(model_kwargs)

    self._set_indices_and_labels()

    # No labeled cells at all -> plain loader; otherwise semi-supervised.
    if len(self._labeled_indices) == 0:
        self._dl_cls = AnnDataLoader
    else:
        self._dl_cls = SemiSupervisedDataLoader

    # The unlabeled category does not count as a class for the model.
    if self.has_unlabeled:
        n_labels = self.summary_stats.n_labels - 1
    else:
        n_labels = self.summary_stats.n_labels

    if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry:
        n_cats_per_cov = self.adata_manager.get_state_registry(
            REGISTRY_KEYS.CAT_COVS_KEY
        ).n_cats_per_key
    else:
        n_cats_per_cov = None

    n_batch = self.summary_stats.n_batch

    # Library-size priors are only needed when no size factor is registered.
    use_size_factor_key = (
        REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry
    )
    if use_size_factor_key:
        library_log_means, library_log_vars = None, None
    else:
        library_log_means, library_log_vars = _init_library_size(
            self.adata_manager, n_batch
        )

    self.module = SCANVAE(
        n_input=self.summary_stats.n_vars,
        n_batch=n_batch,
        n_labels=n_labels,
        n_continuous_cov=self.summary_stats.get("n_extra_continuous_covs", 0),
        n_cats_per_cov=n_cats_per_cov,
        n_hidden=n_hidden,
        n_latent=n_latent,
        n_layers=n_layers,
        dropout_rate=dropout_rate,
        dispersion=dispersion,
        gene_likelihood=gene_likelihood,
        use_size_factor_key=use_size_factor_key,
        library_log_means=library_log_means,
        library_log_vars=library_log_vars,
        **scanvae_model_kwargs,
    )

    self.unsupervised_history_ = None
    self.semisupervised_history_ = None

    self._model_summary_string = (
        "ScanVI Model with the following params: \n"
        f"unlabeled_category: {self.unlabeled_category_}, n_hidden: {n_hidden}, "
        f"n_latent: {n_latent}, n_layers: {n_layers}, "
        f"dropout_rate: {dropout_rate}, dispersion: {dispersion}, "
        f"gene_likelihood: {gene_likelihood}"
    )
    # locals() is handed to the helper, so no parameter may be rebound above.
    self.init_params_ = self._get_init_params(locals())
    self.was_pretrained = False