def test_pyro_bayesian_regression_jit():
    use_gpu = int(torch.cuda.is_available())
    adata = synthetic_iid()
    train_dl = AnnDataLoader(adata, shuffle=True, batch_size=128)
    pyro.clear_param_store()
    model = BayesianRegressionModule(adata.shape[1], 1)
    train_dl = AnnDataLoader(adata, shuffle=True, batch_size=128)
    plan = PyroTrainingPlan(model, loss_fn=pyro.infer.JitTrace_ELBO())
    trainer = Trainer(
        gpus=use_gpu, max_epochs=2, callbacks=[PyroJitGuideWarmup(train_dl)]
    )
    trainer.fit(plan, train_dl)

    # 100 features, 1 for sigma, 1 for bias
    assert list(model.guide.parameters())[0].shape[0] == 102

    if use_gpu == 1:
        model.cuda()

    # test Predictive
    num_samples = 5
    predictive = model.create_predictive(num_samples=num_samples)
    for tensor_dict in train_dl:
        args, kwargs = model._get_fn_args_from_batch(tensor_dict)
        _ = {
            k: v.detach().cpu().numpy()
            for k, v in predictive(*args, **kwargs).items()
            if k != "obs"
        }
def test_pyro_bayesian_regression(save_path):
    use_gpu = int(torch.cuda.is_available())
    adata = synthetic_iid()
    train_dl = AnnDataLoader(adata, shuffle=True, batch_size=128)
    pyro.clear_param_store()
    model = BayesianRegressionModule(adata.shape[1], 1)
    plan = PyroTrainingPlan(model)
    trainer = Trainer(
        gpus=use_gpu,
        max_epochs=2,
    )
    trainer.fit(plan, train_dl)
    if use_gpu == 1:
        model.cuda()

    # test Predictive
    num_samples = 5
    predictive = model.create_predictive(num_samples=num_samples)
    for tensor_dict in train_dl:
        args, kwargs = model._get_fn_args_from_batch(tensor_dict)
        _ = {
            k: v.detach().cpu().numpy()
            for k, v in predictive(*args, **kwargs).items()
            if k != "obs"
        }

    # test save and load
    # cpu/gpu has minor difference
    model.cpu()
    quants = model.guide.quantiles([0.5])
    sigma_median = quants["sigma"][0].detach().cpu().numpy()
    linear_median = quants["linear.weight"][0].detach().cpu().numpy()

    model_save_path = os.path.join(save_path, "model_params.pt")
    torch.save(model.state_dict(), model_save_path)

    pyro.clear_param_store()
    new_model = BayesianRegressionModule(adata.shape[1], 1)
    # run model one step to get autoguide params
    try:
        new_model.load_state_dict(torch.load(model_save_path))
    except RuntimeError as err:
        if isinstance(new_model, PyroBaseModuleClass):
            plan = PyroTrainingPlan(new_model)
            trainer = Trainer(
                gpus=use_gpu,
                max_steps=1,
            )
            trainer.fit(plan, train_dl)
            new_model.load_state_dict(torch.load(model_save_path))
        else:
            raise err

    quants = new_model.guide.quantiles([0.5])
    sigma_median_new = quants["sigma"][0].detach().cpu().numpy()
    linear_median_new = quants["linear.weight"][0].detach().cpu().numpy()

    np.testing.assert_array_equal(sigma_median_new, sigma_median)
    np.testing.assert_array_equal(linear_median_new, linear_median)
def test_pyro_bayesian_regression_jit():
    use_gpu = 0
    adata = synthetic_iid()
    train_dl = AnnDataLoader(adata, shuffle=True, batch_size=128)
    pyro.clear_param_store()
    model = BayesianRegressionModule(adata.shape[1], 1)
    # warmup guide for JIT
    for tensors in train_dl:
        args, kwargs = model._get_fn_args_from_batch(tensors)
        model.guide(*args, **kwargs)
        break
    train_dl = AnnDataLoader(adata, shuffle=True, batch_size=128)
    plan = PyroTrainingPlan(model, loss_fn=pyro.infer.JitTrace_ELBO())
    trainer = Trainer(
        gpus=use_gpu,
        max_epochs=2,
    )
    trainer.fit(plan, train_dl)
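# Both JIT tests above must run the guide once before `pyro.infer.JitTrace_ELBO`
# compiles the loss: the first test does this via a `PyroJitGuideWarmup` callback,
# the second inline. The sketch below is a hypothetical illustration of such a
# warmup callback (it is NOT the scvi-tools implementation); it assumes the
# standard PyTorch Lightning `Callback.on_train_start` hook and that the training
# plan exposes the Pyro module as `pl_module.module`.
from pytorch_lightning import Callback


class GuideWarmupCallback(Callback):
    """Hypothetical callback: run the Pyro guide on one batch before training."""

    def __init__(self, dataloader):
        self.dataloader = dataloader

    def on_train_start(self, trainer, pl_module):
        # Assumption: `pl_module.module` is the PyroBaseModuleClass with a `guide`.
        pyro_module = pl_module.module
        for tensors in self.dataloader:
            args, kwargs = pyro_module._get_fn_args_from_batch(tensors)
            pyro_module.guide(*args, **kwargs)
            break  # one batch is enough to initialize the autoguide parameters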
def test_pyro_bayesian_regression(save_path):
    use_gpu = 0
    adata = synthetic_iid()
    train_dl = AnnDataLoader(adata, shuffle=True, batch_size=128)
    pyro.clear_param_store()
    model = BayesianRegressionModule(adata.shape[1], 1)
    plan = PyroTrainingPlan(model)
    trainer = Trainer(
        gpus=use_gpu,
        max_epochs=2,
    )
    trainer.fit(plan, train_dl)

    # test save and load
    post_dl = AnnDataLoader(adata, shuffle=False, batch_size=128)
    mean1 = []
    with torch.no_grad():
        for tensors in post_dl:
            args, kwargs = model._get_fn_args_from_batch(tensors)
            mean1.append(model(*args, **kwargs).cpu().numpy())
    mean1 = np.concatenate(mean1)

    model_save_path = os.path.join(save_path, "model_params.pt")
    torch.save(model.state_dict(), model_save_path)

    pyro.clear_param_store()
    new_model = BayesianRegressionModule(adata.shape[1], 1)
    # run model one step to get autoguide params
    try:
        new_model.load_state_dict(torch.load(model_save_path))
    except RuntimeError as err:
        if isinstance(new_model, PyroBaseModuleClass):
            plan = PyroTrainingPlan(new_model)
            trainer = Trainer(
                gpus=use_gpu,
                max_steps=1,
            )
            trainer.fit(plan, train_dl)
            new_model.load_state_dict(torch.load(model_save_path))
        else:
            raise err

    mean2 = []
    with torch.no_grad():
        for tensors in post_dl:
            args, kwargs = new_model._get_fn_args_from_batch(tensors)
            mean2.append(new_model(*args, **kwargs).cpu().numpy())
    mean2 = np.concatenate(mean2)

    np.testing.assert_array_equal(mean1, mean2)
class SCANVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseModelClass):
    """
    Single-cell annotation using variational inference [Xu20]_.

    Inspired by the M1 + M2 model, as described in (https://arxiv.org/pdf/1406.5298.pdf).

    Parameters
    ----------
    adata
        AnnData object that has been registered via :func:`~scvi.data.setup_anndata`.
    unlabeled_category
        Value used for unlabeled cells in `labels_key` used to setup AnnData with scvi.
    n_hidden
        Number of nodes per hidden layer.
    n_latent
        Dimensionality of the latent space.
    n_layers
        Number of hidden layers used for encoder and decoder NNs.
    dropout_rate
        Dropout rate for neural networks.
    dispersion
        One of the following:

        * ``'gene'`` - dispersion parameter of NB is constant per gene across cells
        * ``'gene-batch'`` - dispersion can differ between different batches
        * ``'gene-label'`` - dispersion can differ between different labels
        * ``'gene-cell'`` - dispersion can differ for every gene in every cell
    gene_likelihood
        One of:

        * ``'nb'`` - Negative binomial distribution
        * ``'zinb'`` - Zero-inflated negative binomial distribution
        * ``'poisson'`` - Poisson distribution
    use_gpu
        Use the GPU or not.
    **model_kwargs
        Keyword args for :class:`~scvi.modules.SCANVAE`

    Examples
    --------
    >>> adata = anndata.read_h5ad(path_to_anndata)
    >>> scvi.data.setup_anndata(adata, batch_key="batch", labels_key="labels")
    >>> vae = scvi.model.SCANVI(adata, "Unknown")
    >>> vae.train()
    >>> adata.obsm["X_scVI"] = vae.get_latent_representation()
    >>> adata.obs["pred_label"] = vae.predict()
    """

    def __init__(
        self,
        adata: AnnData,
        unlabeled_category: Union[str, int, float],
        n_hidden: int = 128,
        n_latent: int = 10,
        n_layers: int = 1,
        dropout_rate: float = 0.1,
        dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene",
        gene_likelihood: Literal["zinb", "nb", "poisson"] = "zinb",
        use_gpu: bool = True,
        **model_kwargs,
    ):
        super(SCANVI, self).__init__(adata, use_gpu=use_gpu)
        scanvae_model_kwargs = dict(model_kwargs)

        self.unlabeled_category_ = unlabeled_category
        has_unlabeled = self._set_indices_and_labels()

        if len(self._labeled_indices) != 0:
            self._dl_cls = SemiSupervisedDataLoader
        else:
            self._dl_cls = ScviDataLoader

        # ignores unlabeled category
        n_labels = (
            self.summary_stats["n_labels"] - 1
            if has_unlabeled
            else self.summary_stats["n_labels"]
        )
        n_cats_per_cov = (
            self.scvi_setup_dict_["extra_categoricals"]["n_cats_per_key"]
            if "extra_categoricals" in self.scvi_setup_dict_
            else None
        )
        self.model = SCANVAE(
            n_input=self.summary_stats["n_vars"],
            n_batch=self.summary_stats["n_batch"],
            n_labels=n_labels,
            n_continuous_cov=self.summary_stats["n_continuous_covs"],
            n_cats_per_cov=n_cats_per_cov,
            n_hidden=n_hidden,
            n_latent=n_latent,
            n_layers=n_layers,
            dropout_rate=dropout_rate,
            dispersion=dispersion,
            gene_likelihood=gene_likelihood,
            **scanvae_model_kwargs,
        )

        self.unsupervised_history_ = None
        self.semisupervised_history_ = None

        self._model_summary_string = (
            "ScanVI Model with the following params: \nunlabeled_category: {}, n_hidden: {}, n_latent: {}"
            ", n_layers: {}, dropout_rate: {}, dispersion: {}, gene_likelihood: {}"
        ).format(
            unlabeled_category,
            n_hidden,
            n_latent,
            n_layers,
            dropout_rate,
            dispersion,
            gene_likelihood,
        )
        self.init_params_ = self._get_init_params(locals())

    def _set_indices_and_labels(self):
        """
        Set indices and make the unlabeled category the last category.
        Returns
        -------
        True if categories were reordered, else False.
        """
        # get indices for labeled and unlabeled cells
        key = self.scvi_setup_dict_["data_registry"][_CONSTANTS.LABELS_KEY]["attr_key"]
        mapping = self.scvi_setup_dict_["categorical_mappings"][key]["mapping"]
        original_key = self.scvi_setup_dict_["categorical_mappings"][key]["original_key"]
        labels = np.asarray(self.adata.obs[original_key]).ravel()

        if self.unlabeled_category_ in labels:
            unlabeled_idx = np.where(mapping == self.unlabeled_category_)
            unlabeled_idx = unlabeled_idx[0][0]
            # move unlabeled category to be the last position
            mapping[unlabeled_idx], mapping[-1] = mapping[-1], mapping[unlabeled_idx]
            cat_dtype = CategoricalDtype(categories=mapping, ordered=True)
            # rerun setup for the labels column
            _make_obs_column_categorical(
                self.adata,
                original_key,
                "_scvi_labels",
                categorical_dtype=cat_dtype,
            )
            remapped = True
        else:
            remapped = False

        self.scvi_setup_dict_ = self.adata.uns["_scvi"]
        self._label_mapping = mapping
        # set unlabeled and labeled indices
        self._unlabeled_indices = np.argwhere(labels == self.unlabeled_category_).ravel()
        self._labeled_indices = np.argwhere(labels != self.unlabeled_category_).ravel()
        self._code_to_label = {i: l for i, l in enumerate(self._label_mapping)}
        self.original_label_key = original_key

        return remapped

    @property
    def _task_class(self):
        return SemiSupervisedTask

    @property
    def _data_loader_cls(self):
        return ScviDataLoader

    @property
    def history(self):
        """Returns computed metrics during training."""
        return self._trainer.logger.history

    def train(
        self,
        max_epochs: Optional[int] = None,
        n_samples_per_label: Optional[float] = None,
        check_val_every_n_epoch: Optional[int] = None,
        train_size: float = 0.9,
        validation_size: Optional[float] = None,
        batch_size: int = 128,
        use_gpu: Optional[bool] = None,
        vae_task_kwargs: Optional[dict] = None,
        **kwargs,
    ):
        """
        Train the model.

        Parameters
        ----------
        max_epochs
            Number of passes through the dataset for semisupervised training.
        n_samples_per_label
            Number of subsamples for each label class to sample per epoch.
            By default, there is no label subsampling.
        check_val_every_n_epoch
            Frequency with which metrics are computed on the data for validation set for both
            the unsupervised and semisupervised trainers. If you'd like a different frequency for
            the semisupervised trainer, set check_val_every_n_epoch in semisupervised_train_kwargs.
        train_size
            Size of training set in the range [0.0, 1.0].
        validation_size
            Size of the test set. If `None`, defaults to 1 - `train_size`. If
            `train_size + validation_size < 1`, the remaining cells belong to a test set.
        batch_size
            Minibatch size to use during training.
        use_gpu
            If `True`, use the GPU if available. Will override the use_gpu option when initializing model.
        vae_task_kwargs
            Keyword args for :class:`~scvi.lightning.SemiSupervisedTask`. Keyword arguments passed to
            `train()` will overwrite values present in `vae_task_kwargs`, when appropriate.
        **kwargs
            Other keyword args for :class:`~scvi.lightning.Trainer`.
""" if max_epochs is None: n_cells = self.adata.n_obs max_epochs = np.min([round((20000 / n_cells) * 400), 400]) logger.info("Training for {} epochs.".format(max_epochs)) use_gpu = use_gpu if use_gpu is not None else self.use_gpu gpus = 1 if use_gpu else None pin_memory = (True if (settings.dl_pin_memory_gpu_training and use_gpu) else False) train_dl, val_dl, test_dl = self._train_test_val_split( self.adata, unlabeled_category=self.unlabeled_category_, train_size=train_size, validation_size=validation_size, n_samples_per_label=n_samples_per_label, pin_memory=pin_memory, batch_size=batch_size, ) self.train_indices_ = train_dl.indices self.validation_indices_ = val_dl.indices self.test_indices_ = test_dl.indices vae_task_kwargs = {} if vae_task_kwargs is None else vae_task_kwargs self._task = SemiSupervisedTask(self.model, **vae_task_kwargs) # if we have labeled cells, we want to subsample labels each epoch sampler_callback = ([SubSampleLabels()] if len(self._labeled_indices) != 0 else []) self._trainer = Trainer( max_epochs=max_epochs, gpus=gpus, callbacks=sampler_callback, check_val_every_n_epoch=check_val_every_n_epoch, **kwargs, ) if len(self.validation_indices_) != 0: self._trainer.fit(self._task, train_dl, val_dl) else: self._trainer.fit(self._task, train_dl) self.model.eval() self.is_trained_ = True def predict( self, adata: Optional[AnnData] = None, indices: Optional[Sequence[int]] = None, soft: bool = False, batch_size: Optional[int] = None, ) -> Union[np.ndarray, pd.DataFrame]: """ Return cell label predictions. Parameters ---------- adata AnnData object that has been registered via :func:`~scvi.data.setup_anndata`. indices Return probabilities for each class label. soft If True, returns per class probabilities batch_size Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. """ adata = self._validate_anndata(adata) if indices is None: indices = np.arange(adata.n_obs) scdl = self._make_scvi_dl( adata=adata, indices=indices, batch_size=batch_size, ) y_pred = [] for _, tensors in enumerate(scdl): x = tensors[_CONSTANTS.X_KEY] batch = tensors[_CONSTANTS.BATCH_KEY] pred = self.model.classify(x, batch) if not soft: pred = pred.argmax(dim=1) y_pred.append(pred.detach().cpu()) y_pred = np.array(torch.cat(y_pred)) if not soft: predictions = [] for p in y_pred: predictions.append(self._code_to_label[p]) return np.array(predictions) else: n_labels = len(pred[0]) pred = pd.DataFrame( y_pred, columns=self._label_mapping[:n_labels], index=adata.obs_names[indices], ) return y_pred def _train_test_val_split( self, adata: AnnData, unlabeled_category, train_size: float = 0.9, validation_size: Optional[float] = None, n_samples_per_label: Optional[int] = None, **kwargs, ): """ Creates data loaders ``train_set``, ``validation_set``, ``test_set``. If ``train_size + validation_set < 1`` then ``test_set`` is non-empty. The ratio between labeled and unlabeled data in adata will be preserved in the train/test/val sets. 
        Parameters
        ----------
        adata
            AnnData to split into train/test/val sets
        unlabeled_category
            Category to treat as unlabeled
        train_size
            float, or None (default is 0.9)
        validation_size
            float, or None (default is None)
        n_samples_per_label
            Number of subsamples for each label class to sample per epoch
        **kwargs
            Keyword args for `_make_scvi_dl()`
        """
        train_size = float(train_size)
        if train_size > 1.0 or train_size <= 0.0:
            raise ValueError(
                "train_size needs to be greater than 0 and less than or equal to 1"
            )

        n_labeled_idx = len(self._labeled_indices)
        n_unlabeled_idx = len(self._unlabeled_indices)

        def get_train_val_split(n_samples, test_size, train_size):
            try:
                n_train, n_val = _validate_shuffle_split(n_samples, test_size, train_size)
            except ValueError:
                if train_size != 1.0 and n_samples != 1:
                    raise ValueError(
                        "Choice of train_size={} and validation_size={} not understood".format(
                            train_size, test_size
                        )
                    )
                n_train, n_val = n_samples, 0
            return n_train, n_val

        if n_labeled_idx != 0:
            n_labeled_train, n_labeled_val = get_train_val_split(
                n_labeled_idx, validation_size, train_size
            )
            labeled_permutation = np.random.choice(
                self._labeled_indices, len(self._labeled_indices), replace=False
            )
            labeled_idx_val = labeled_permutation[:n_labeled_val]
            labeled_idx_train = labeled_permutation[
                n_labeled_val : (n_labeled_val + n_labeled_train)
            ]
            labeled_idx_test = labeled_permutation[(n_labeled_val + n_labeled_train) :]
        else:
            labeled_idx_test = []
            labeled_idx_train = []
            labeled_idx_val = []

        if n_unlabeled_idx != 0:
            n_unlabeled_train, n_unlabeled_val = get_train_val_split(
                n_unlabeled_idx, validation_size, train_size
            )
            unlabeled_permutation = np.random.choice(
                self._unlabeled_indices, len(self._unlabeled_indices), replace=False
            )
            unlabeled_idx_val = unlabeled_permutation[:n_unlabeled_val]
            unlabeled_idx_train = unlabeled_permutation[
                n_unlabeled_val : (n_unlabeled_val + n_unlabeled_train)
            ]
            unlabeled_idx_test = unlabeled_permutation[
                (n_unlabeled_val + n_unlabeled_train) :
            ]
        else:
            unlabeled_idx_train = []
            unlabeled_idx_val = []
            unlabeled_idx_test = []

        indices_train = np.concatenate((labeled_idx_train, unlabeled_idx_train))
        indices_val = np.concatenate((labeled_idx_val, unlabeled_idx_val))
        indices_test = np.concatenate((labeled_idx_test, unlabeled_idx_test))

        indices_train = indices_train.astype(int)
        indices_val = indices_val.astype(int)
        indices_test = indices_test.astype(int)

        if len(self._labeled_indices) != 0:
            dataloader_class = SemiSupervisedDataLoader
            dl_kwargs = {
                "unlabeled_category": unlabeled_category,
                "n_samples_per_label": n_samples_per_label,
            }
        else:
            dataloader_class = ScviDataLoader
            dl_kwargs = {}

        dl_kwargs.update(kwargs)

        scanvi_train_dl = self._make_scvi_dl(
            adata,
            indices=indices_train,
            shuffle=True,
            scvi_dl_class=dataloader_class,
            **dl_kwargs,
        )
        scanvi_val_dl = self._make_scvi_dl(
            adata,
            indices=indices_val,
            shuffle=True,
            scvi_dl_class=dataloader_class,
            **dl_kwargs,
        )
        scanvi_test_dl = self._make_scvi_dl(
            adata,
            indices=indices_test,
            shuffle=True,
            scvi_dl_class=dataloader_class,
            **dl_kwargs,
        )

        return scanvi_train_dl, scanvi_val_dl, scanvi_test_dl
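# The labeled/unlabeled split above reduces to the same arithmetic used in
# `BaseModelClass._train_test_val_split`: compute (n_train, n_val) from
# `train_size`/`validation_size` via sklearn's `_validate_shuffle_split`, then
# slice a random permutation of the indices. A rough standalone sketch of the
# idea (simplified floor-based rounding, not the exact sklearn behaviour):
import numpy as np


def split_counts(n, train_size=0.9, validation_size=None):
    """Return (n_train, n_val); the remaining indices form the test set."""
    if validation_size is None:
        validation_size = 1.0 - train_size
    n_train = int(np.floor(n * train_size))
    n_val = int(np.floor(n * validation_size))
    return n_train, n_val


# e.g. 1000 labeled cells with train_size=0.9 -> 900 train, 100 val, 0 test
n_train, n_val = split_counts(1000, train_size=0.9)
permutation = np.random.permutation(1000)
idx_val = permutation[:n_val]
idx_train = permutation[n_val : n_val + n_train]
idx_test = permutation[n_val + n_train :]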
class BaseModelClass(ABC):
    def __init__(
        self, adata: Optional[AnnData] = None, use_gpu: Optional[bool] = None
    ):
        if adata is not None:
            if "_scvi" not in adata.uns.keys():
                raise ValueError(
                    "Please setup your AnnData with scvi.data.setup_anndata(adata) first"
                )
            self.adata = adata
            self.scvi_setup_dict_ = adata.uns["_scvi"]
            self.summary_stats = self.scvi_setup_dict_["summary_stats"]
            self._validate_anndata(adata, copy_if_view=False)

        self.is_trained_ = False
        cuda_avail = torch.cuda.is_available()
        self.use_gpu = cuda_avail if use_gpu is None else (use_gpu and cuda_avail)
        self._model_summary_string = ""
        self.train_indices_ = None
        self.test_indices_ = None
        self.validation_indices_ = None
        self.history_ = None

    def _make_scvi_dl(
        self,
        adata: AnnData,
        indices: Optional[Sequence[int]] = None,
        batch_size: Optional[int] = None,
        shuffle: bool = False,
        scvi_dl_class=None,
        **data_loader_kwargs,
    ):
        """
        Create an AnnDataLoader object for data iteration.

        Parameters
        ----------
        adata
            AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
            AnnData object used to initialize the model.
        indices
            Indices of cells in adata to use. If `None`, all cells are used.
        batch_size
            Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
        shuffle
            Whether observations are shuffled each iteration through the data.
        data_loader_kwargs
            Kwargs to the class-specific data loader class
        """
        if batch_size is None:
            batch_size = settings.batch_size
        if indices is None:
            indices = np.arange(adata.n_obs)
        if scvi_dl_class is None:
            scvi_dl_class = self._data_loader_cls
        if "num_workers" not in data_loader_kwargs:
            data_loader_kwargs.update({"num_workers": settings.dl_num_workers})
        dl = scvi_dl_class(
            adata,
            shuffle=shuffle,
            indices=indices,
            batch_size=batch_size,
            **data_loader_kwargs,
        )
        return dl

    def _train_test_val_split(
        self,
        adata: AnnData,
        train_size: float = 0.9,
        validation_size: Optional[float] = None,
        **kwargs,
    ):
        """
        Creates data loaders ``train_set``, ``validation_set``, ``test_set``.

        If ``train_size + validation_set < 1`` then ``test_set`` is non-empty.
        Parameters
        ----------
        adata
            Setup AnnData to be split into train, test, validation sets
        train_size
            float, or None (default is 0.9)
        validation_size
            float, or None (default is None)
        **kwargs
            Keyword args for `_make_scvi_dl()`
        """
        train_size = float(train_size)
        if train_size > 1.0 or train_size <= 0.0:
            raise ValueError(
                "train_size needs to be greater than 0 and less than or equal to 1"
            )

        n = len(adata)
        try:
            n_train, n_val = _validate_shuffle_split(n, validation_size, train_size)
        except ValueError:
            if train_size != 1.0:
                raise ValueError(
                    "Choice of train_size={} and validation_size={} not understood".format(
                        train_size, validation_size
                    )
                )
            n_train, n_val = n, 0
        random_state = np.random.RandomState(seed=settings.seed)
        permutation = random_state.permutation(n)
        indices_validation = permutation[:n_val]
        indices_train = permutation[n_val : (n_val + n_train)]
        indices_test = permutation[(n_val + n_train) :]

        return (
            self._make_scvi_dl(adata, indices=indices_train, shuffle=True, **kwargs),
            self._make_scvi_dl(adata, indices=indices_validation, shuffle=True, **kwargs),
            self._make_scvi_dl(adata, indices=indices_test, shuffle=True, **kwargs),
        )

    def _validate_anndata(
        self, adata: Optional[AnnData] = None, copy_if_view: bool = True
    ):
        """Validate anndata has been properly registered, transfer if necessary."""
        if adata is None:
            adata = self.adata
        if adata.is_view:
            if copy_if_view:
                logger.info("Received view of anndata, making copy.")
                adata = adata.copy()
            else:
                raise ValueError("Please run `adata = adata.copy()`")

        if "_scvi" not in adata.uns_keys():
            logger.info(
                "Input adata not setup with scvi. "
                + "attempting to transfer anndata setup"
            )
            transfer_anndata_setup(self.scvi_setup_dict_, adata)
        is_nonneg_int = _check_nonnegative_integers(
            get_from_registry(adata, _CONSTANTS.X_KEY)
        )
        if not is_nonneg_int:
            logger.warning(
                "Make sure the registered X field in anndata contains unnormalized count data."
            )
        _check_anndata_setup_equivalence(self.scvi_setup_dict_, adata)
        return adata

    @property
    @abstractmethod
    def _data_loader_cls(self):
        pass

    @property
    @abstractmethod
    def _plan_class(self):
        pass

    @property
    def is_trained(self):
        return self.is_trained_

    @property
    def test_indices(self):
        return self.test_indices_

    @property
    def train_indices(self):
        return self.train_indices_

    @property
    def validation_indices(self):
        return self.validation_indices_

    @property
    def history(self):
        """Returns computed metrics during training."""
        return self.history_

    def _get_user_attributes(self):
        """Returns all the self attributes defined in a model class, e.g., self.is_trained_."""
        attributes = inspect.getmembers(self, lambda a: not (inspect.isroutine(a)))
        attributes = [
            a for a in attributes if not (a[0].startswith("__") and a[0].endswith("__"))
        ]
        attributes = [a for a in attributes if not a[0].startswith("_abc_")]
        return attributes

    def _get_init_params(self, locals):
        """
        Returns the model init signature with associated passed in values.

        Ignores the initial AnnData.
""" init = self.__init__ sig = inspect.signature(init) parameters = sig.parameters.values() init_params = [p.name for p in parameters] all_params = {p: locals[p] for p in locals if p in init_params} all_params = { k: v for (k, v) in all_params.items() if not isinstance(v, AnnData) } # not very efficient but is explicit # seperates variable params (**kwargs) from non variable params into two dicts non_var_params = [ p.name for p in parameters if p.kind != p.VAR_KEYWORD ] non_var_params = { k: v for (k, v) in all_params.items() if k in non_var_params } var_params = [p.name for p in parameters if p.kind == p.VAR_KEYWORD] var_params = {k: v for (k, v) in all_params.items() if k in var_params} user_params = {"kwargs": var_params, "non_kwargs": non_var_params} return user_params def train( self, max_epochs: Optional[int] = None, use_gpu: Optional[bool] = None, train_size: float = 0.9, validation_size: Optional[float] = None, batch_size: int = 128, plan_kwargs: Optional[dict] = None, plan_class: Optional[None] = None, **kwargs, ): """ Train the model. Parameters ---------- max_epochs Number of passes through the dataset. If `None`, defaults to `np.min([round((20000 / n_cells) * 400), 400])` use_gpu If `True`, use the GPU if available. Will override the use_gpu option when initializing model train_size Size of training set in the range [0.0, 1.0]. validation_size Size of the test set. If `None`, defaults to 1 - `train_size`. If `train_size + validation_size < 1`, the remaining cells belong to a test set. batch_size Minibatch size to use during training. plan_kwargs Keyword args for model-specific Pytorch Lightning task. Keyword arguments passed to `train()` will overwrite values present in `plan_kwargs`, when appropriate. plan_class Optional override to use a specific TrainingPlan-type class. **kwargs Other keyword args for :class:`~scvi.lightning.Trainer`. """ if use_gpu is None: use_gpu = self.use_gpu else: use_gpu = use_gpu and torch.cuda.is_available() gpus = 1 if use_gpu else None pin_memory = (True if (settings.dl_pin_memory_gpu_training and use_gpu) else False) if max_epochs is None: n_cells = self.adata.n_obs max_epochs = np.min([round((20000 / n_cells) * 400), 400]) self.trainer = Trainer( max_epochs=max_epochs, gpus=gpus, **kwargs, ) train_dl, val_dl, test_dl = self._train_test_val_split( self.adata, train_size=train_size, validation_size=validation_size, pin_memory=pin_memory, batch_size=batch_size, ) self.train_indices_ = train_dl.indices self.test_indices_ = test_dl.indices self.validation_indices_ = val_dl.indices if plan_class is None: plan_class = self._plan_class plan_kwargs = plan_kwargs if isinstance(plan_kwargs, dict) else dict() self._pl_task = plan_class(self.module, len(self.train_indices_), **plan_kwargs) if train_size == 1.0: # circumvent the empty data loader problem if all dataset used for training self.trainer.fit(self._pl_task, train_dl) else: self.trainer.fit(self._pl_task, train_dl, val_dl) try: self.history_ = self.trainer.logger.history except AttributeError: self.history_ = None self.module.eval() if use_gpu: self.module.cuda() self.is_trained_ = True def save( self, dir_path: str, overwrite: bool = False, save_anndata: bool = False, **anndata_write_kwargs, ): """ Save the state of the model. Neither the trainer optimizer state nor the trainer history are saved. Model files are not expected to be reproducibly saved and loaded across versions until we reach version 1.0. Parameters ---------- dir_path Path to a directory. overwrite Overwrite existing data or not. 
            If `False` and directory already exists at `dir_path`, error will be raised.
        save_anndata
            If True, also saves the anndata
        anndata_write_kwargs
            Kwargs for :func:`~anndata.AnnData.write`
        """
        # get all the user attributes
        user_attributes = self._get_user_attributes()
        # only save the public attributes with _ at the very end
        user_attributes = {a[0]: a[1] for a in user_attributes if a[0][-1] == "_"}
        # save the model state dict and the trainer state dict only
        if not os.path.exists(dir_path) or overwrite:
            os.makedirs(dir_path, exist_ok=overwrite)
        else:
            raise ValueError(
                "{} already exists. Please provide a different directory for saving.".format(
                    dir_path
                )
            )

        if save_anndata:
            self.adata.write(
                os.path.join(dir_path, "adata.h5ad"), **anndata_write_kwargs
            )

        model_save_path = os.path.join(dir_path, "model_params.pt")
        attr_save_path = os.path.join(dir_path, "attr.pkl")
        varnames_save_path = os.path.join(dir_path, "var_names.csv")

        var_names = self.adata.var_names.astype(str)
        var_names = var_names.to_numpy()
        np.savetxt(varnames_save_path, var_names, fmt="%s")

        torch.save(self.module.state_dict(), model_save_path)
        with open(attr_save_path, "wb") as f:
            pickle.dump(user_attributes, f)

    @classmethod
    def load(
        cls,
        dir_path: str,
        adata: Optional[AnnData] = None,
        use_gpu: Optional[bool] = None,
    ):
        """
        Instantiate a model from the saved output.

        Parameters
        ----------
        dir_path
            Path to saved outputs.
        adata
            AnnData organized in the same way as data used to train model.
            It is not necessary to run :func:`~scvi.data.setup_anndata`,
            as AnnData is validated against the saved `scvi` setup dictionary.
            If None, will check for and load anndata saved with the model.
        use_gpu
            Whether to load model on GPU.

        Returns
        -------
        Model with loaded state dictionaries.

        Examples
        --------
        >>> vae = SCVI.load(save_path, adata)
        >>> vae.get_latent_representation()
        """
        load_adata = adata is None
        if use_gpu is None:
            use_gpu = torch.cuda.is_available()
        map_location = torch.device("cpu") if use_gpu is False else None
        (
            scvi_setup_dict,
            attr_dict,
            var_names,
            model_state_dict,
            new_adata,
        ) = _load_saved_files(dir_path, load_adata, map_location=map_location)
        adata = new_adata if new_adata is not None else adata

        _validate_var_names(adata, var_names)
        transfer_anndata_setup(scvi_setup_dict, adata)
        model = _initialize_model(cls, adata, attr_dict, use_gpu)

        # set saved attrs for loaded model
        for attr, val in attr_dict.items():
            setattr(model, attr, val)

        model.module.load_state_dict(model_state_dict)
        if use_gpu:
            model.module.cuda()

        model.module.eval()
        model._validate_anndata(adata)
        return model

    def __repr__(self):
        summary_string = self._model_summary_string
        summary_string += "\nTraining status: {}".format(
            "Trained" if self.is_trained_ else "Not Trained"
        )
        rich.print(summary_string)

        command = "scvi.data.view_anndata_setup(model.adata)"
        command_len = len(command)
        print_adata_str = "\n\nTo print summary of associated AnnData, use: " + command
        text = Text(print_adata_str)
        text.stylize(
            "dark_violet", len(print_adata_str) - command_len, len(print_adata_str)
        )
        console = rich.console.Console()
        console.print(text)
        return ""
class TrainRunner:
    """
    TrainRunner calls Trainer.fit() and handles pre- and post-training procedures.

    Parameters
    ----------
    model
        model to train
    training_plan
        initialized TrainingPlan
    data_splitter
        initialized :class:`~scvi.dataloaders.SemiSupervisedDataSplitter` or
        :class:`~scvi.dataloaders.DataSplitter`
    max_epochs
        max_epochs to train for
    use_gpu
        Use default GPU if available (if None or True), or index of GPU to use (if int),
        or name of GPU (if str), or use CPU (if False).
    trainer_kwargs
        Extra kwargs for :class:`~scvi.lightning.Trainer`

    Examples
    --------
    >>> # Following code should be within a subclass of BaseModelClass
    >>> data_splitter = DataSplitter(self.adata)
    >>> training_plan = TrainingPlan(self.module, len(data_splitter.train_idx))
    >>> runner = TrainRunner(
    >>>     self,
    >>>     training_plan=training_plan,
    >>>     data_splitter=data_splitter,
    >>>     max_epochs=max_epochs)
    >>> runner()
    """

    def __init__(
        self,
        model: BaseModelClass,
        training_plan: pl.LightningModule,
        data_splitter: Union[SemiSupervisedDataSplitter, DataSplitter],
        max_epochs: int,
        use_gpu: Optional[Union[str, int, bool]] = None,
        **trainer_kwargs,
    ):
        self.training_plan = training_plan
        self.data_splitter = data_splitter
        self.model = model
        gpus, device = parse_use_gpu_arg(use_gpu)
        self.gpus = gpus
        self.device = device
        self.trainer = Trainer(max_epochs=max_epochs, gpus=gpus, **trainer_kwargs)

    def __call__(self):
        train_dl, val_dl, test_dl = self.data_splitter()
        self.model.train_indices = train_dl.indices
        self.model.test_indices = test_dl.indices
        self.model.validation_indices = val_dl.indices

        if len(val_dl.indices) == 0:
            # circumvent the empty data loader problem if all dataset used for training
            self.trainer.fit(self.training_plan, train_dl)
        else:
            self.trainer.fit(self.training_plan, train_dl, val_dl)

        try:
            self.model.history_ = self.trainer.logger.history
        except AttributeError:
            self.model.history_ = None

        self.model.module.eval()
        self.model.is_trained_ = True
        self.model.to_device(self.device)
        self.model.trainer = self.trainer
class GIMVI(VAEMixin, BaseModelClass):
    """
    Joint VAE for imputing missing genes in spatial data [Lopez19]_.

    Parameters
    ----------
    adata_seq
        AnnData object that has been registered via :func:`~scvi.data.setup_anndata`
        and contains RNA-seq data.
    adata_spatial
        AnnData object that has been registered via :func:`~scvi.data.setup_anndata`
        and contains spatial data.
    n_hidden
        Number of nodes per hidden layer.
    generative_distributions
        List of generative distributions for adata_seq data and adata_spatial data.
    model_library_size
        List of bool of whether to model library size for adata_seq and adata_spatial.
    n_latent
        Dimensionality of the latent space.
    **model_kwargs
        Keyword args for :class:`~scvi.modules.JVAE`

    Examples
    --------
    >>> adata_seq = anndata.read_h5ad(path_to_anndata_seq)
    >>> adata_spatial = anndata.read_h5ad(path_to_anndata_spatial)
    >>> scvi.data.setup_anndata(adata_seq)
    >>> scvi.data.setup_anndata(adata_spatial)
    >>> vae = scvi.model.GIMVI(adata_seq, adata_spatial)
    >>> vae.train(max_epochs=400)

    Notes
    -----
    See further usage examples in the following tutorials:

    1. :doc:`/user_guide/notebooks/gimvi_tutorial`
    """

    def __init__(
        self,
        adata_seq: AnnData,
        adata_spatial: AnnData,
        generative_distributions: List = ["zinb", "nb"],
        model_library_size: List = [True, False],
        n_latent: int = 10,
        **model_kwargs,
    ):
        super(GIMVI, self).__init__()
        self.adatas = [adata_seq, adata_spatial]
        self.scvi_setup_dicts_ = {
            "seq": adata_seq.uns["_scvi"],
            "spatial": adata_spatial.uns["_scvi"],
        }

        seq_var_names = _get_var_names_from_setup_anndata(adata_seq)
        spatial_var_names = _get_var_names_from_setup_anndata(adata_spatial)

        if not set(spatial_var_names) <= set(seq_var_names):
            raise ValueError("spatial genes need to be a subset of seq genes")

        spatial_gene_loc = [
            np.argwhere(seq_var_names == g)[0] for g in spatial_var_names
        ]
        spatial_gene_loc = np.concatenate(spatial_gene_loc)
        gene_mappings = [slice(None), spatial_gene_loc]
        sum_stats = [d.uns["_scvi"]["summary_stats"] for d in self.adatas]
        n_inputs = [s["n_vars"] for s in sum_stats]

        total_genes = adata_seq.uns["_scvi"]["summary_stats"]["n_vars"]

        # since we are combining datasets, we need to increment the batch_idx
        # of one of the datasets
        adata_seq_n_batches = adata_seq.uns["_scvi"]["summary_stats"]["n_batch"]
        adata_spatial.obs["_scvi_batch"] += adata_seq_n_batches

        n_batches = sum([s["n_batch"] for s in sum_stats])

        self.module = JVAE(
            n_inputs,
            total_genes,
            gene_mappings,
            generative_distributions,
            model_library_size,
            n_batch=n_batches,
            n_latent=n_latent,
            **model_kwargs,
        )

        self._model_summary_string = (
            "GimVI Model with the following params: \nn_latent: {}, n_inputs: {}, n_genes: {}, "
            + "n_batch: {}, generative distributions: {}"
        ).format(n_latent, n_inputs, total_genes, n_batches, generative_distributions)
        self.init_params_ = self._get_init_params(locals())

    def train(
        self,
        max_epochs: int = 200,
        use_gpu: Optional[Union[str, int, bool]] = None,
        kappa: int = 5,
        train_size: float = 0.9,
        validation_size: Optional[float] = None,
        batch_size: int = 128,
        plan_kwargs: Optional[dict] = None,
        **kwargs,
    ):
        """
        Train the model.

        Parameters
        ----------
        max_epochs
            Number of passes through the dataset. If `None`, defaults to
            `np.min([round((20000 / n_cells) * 400), 400])`
        use_gpu
            Use default GPU if available (if None or True), or index of GPU to use (if int),
            or name of GPU (if str), or use CPU (if False).
        kappa
            Scaling parameter for the discriminator loss.
        train_size
            Size of training set in the range [0.0, 1.0].
        validation_size
            Size of the test set.
            If `None`, defaults to 1 - `train_size`. If
            `train_size + validation_size < 1`, the remaining cells belong to a test set.
        batch_size
            Minibatch size to use during training.
        plan_kwargs
            Keyword args for model-specific Pytorch Lightning task. Keyword arguments passed to
            `train()` will overwrite values present in `plan_kwargs`, when appropriate.
        **kwargs
            Other keyword args for :class:`~scvi.lightning.Trainer`.
        """
        gpus, device = parse_use_gpu_arg(use_gpu)

        self.trainer = Trainer(
            max_epochs=max_epochs,
            gpus=gpus,
            **kwargs,
        )
        self.train_indices_, self.test_indices_, self.validation_indices_ = [], [], []
        train_dls, test_dls, val_dls = [], [], []
        for i, ad in enumerate(self.adatas):
            train, val, test = DataSplitter(
                ad,
                train_size=train_size,
                validation_size=validation_size,
                batch_size=batch_size,
                use_gpu=use_gpu,
            )()
            train_dls.append(train)
            test_dls.append(test)
            val.mode = i
            val_dls.append(val)
            self.train_indices_.append(train.indices)
            self.test_indices_.append(test.indices)
            self.validation_indices_.append(val.indices)
        train_dl = TrainDL(train_dls)

        plan_kwargs = plan_kwargs if isinstance(plan_kwargs, dict) else dict()
        self._training_plan = GIMVITrainingPlan(
            self.module,
            len(self.train_indices_),
            adversarial_classifier=True,
            scale_adversarial_loss=kappa,
            **plan_kwargs,
        )

        if train_size == 1.0:
            # circumvent the empty data loader problem if all dataset used for training
            self.trainer.fit(self._training_plan, train_dl)
        else:
            # accepts list of val dataloaders
            self.trainer.fit(self._training_plan, train_dl, val_dls)
        try:
            self.history_ = self.trainer.logger.history
        except AttributeError:
            self.history_ = None
        self.module.eval()

        self.to_device(device)
        self.is_trained_ = True

    def _make_scvi_dls(self, adatas: List[AnnData] = None, batch_size=128):
        if adatas is None:
            adatas = self.adatas
        post_list = [self._make_data_loader(ad) for ad in adatas]
        for i, dl in enumerate(post_list):
            dl.mode = i

        return post_list

    @torch.no_grad()
    def get_latent_representation(
        self,
        adatas: List[AnnData] = None,
        deterministic: bool = True,
        batch_size: int = 128,
    ) -> List[np.ndarray]:
        """
        Return the latent space embedding for each dataset.

        Parameters
        ----------
        adatas
            List of adata seq and adata spatial.
        deterministic
            If true, use the mean of the encoder instead of a Gaussian sample.
        batch_size
            Minibatch size for data loading into model.
        """
        if adatas is None:
            adatas = self.adatas
        scdls = self._make_scvi_dls(adatas, batch_size=batch_size)
        self.module.eval()
        latents = []
        for mode, scdl in enumerate(scdls):
            latent = []
            for tensors in scdl:
                (
                    sample_batch,
                    local_l_mean,
                    local_l_var,
                    batch_index,
                    label,
                    *_,
                ) = _unpack_tensors(tensors)
                latent.append(
                    self.module.sample_from_posterior_z(
                        sample_batch, mode, deterministic=deterministic
                    )
                )

            latent = torch.cat(latent).cpu().detach().numpy()
            latents.append(latent)

        return latents

    @torch.no_grad()
    def get_imputed_values(
        self,
        adatas: List[AnnData] = None,
        deterministic: bool = True,
        normalized: bool = True,
        decode_mode: Optional[int] = None,
        batch_size: int = 128,
    ) -> List[np.ndarray]:
        """
        Return imputed values for all genes for each dataset.

        Parameters
        ----------
        adatas
            List of adata seq and adata spatial
        deterministic
            If true, use the mean of the encoder instead of a Gaussian sample for the latent vector.
        normalized
            Return imputed normalized values or not.
        decode_mode
            If a `decode_mode` is given, use the encoder specific to each dataset as usual but use
            the decoder of the dataset of id `decode_mode` to impute values.
        batch_size
            Minibatch size for data loading into model.
""" self.module.eval() if adatas is None: adatas = self.adatas scdls = self._make_scvi_dls(adatas, batch_size=batch_size) imputed_values = [] for mode, scdl in enumerate(scdls): imputed_value = [] for tensors in scdl: ( sample_batch, local_l_mean, local_l_var, batch_index, label, *_, ) = _unpack_tensors(tensors) if normalized: imputed_value.append( self.module.sample_scale( sample_batch, mode, batch_index, label, deterministic=deterministic, decode_mode=decode_mode, )) else: imputed_value.append( self.module.sample_rate( sample_batch, mode, batch_index, label, deterministic=deterministic, decode_mode=decode_mode, )) imputed_value = torch.cat(imputed_value).cpu().detach().numpy() imputed_values.append(imputed_value) return imputed_values def save( self, dir_path: str, overwrite: bool = False, save_anndata: bool = False, **anndata_write_kwargs, ): """ Save the state of the model. Neither the trainer optimizer state nor the trainer history are saved. Model files are not expected to be reproducibly saved and loaded across versions until we reach version 1.0. Parameters ---------- dir_path Path to a directory. overwrite Overwrite existing data or not. If `False` and directory already exists at `dir_path`, error will be raised. save_anndata If True, also saves the anndata anndata_write_kwargs Kwargs for anndata write function """ # get all the user attributes user_attributes = self._get_user_attributes() # only save the public attributes with _ at the very end user_attributes = { a[0]: a[1] for a in user_attributes if a[0][-1] == "_" } # save the model state dict and the trainer state dict only if not os.path.exists(dir_path) or overwrite: os.makedirs(dir_path, exist_ok=overwrite) else: raise ValueError( "{} already exists. Please provide an unexisting directory for saving." .format(dir_path)) if save_anndata: dataset_names = ["seq", "spatial"] for i in range(len(self.adatas)): save_path = os.path.join( dir_path, "adata_{}.h5ad".format(dataset_names[i])) self.adatas[i].write(save_path) varnames_save_path = os.path.join( dir_path, "var_names_{}.csv".format(dataset_names[i])) var_names = self.adatas[i].var_names.astype(str) var_names = var_names.to_numpy() np.savetxt(varnames_save_path, var_names, fmt="%s") model_save_path = os.path.join(dir_path, "model_params.pt") attr_save_path = os.path.join(dir_path, "attr.pkl") torch.save(self.module.state_dict(), model_save_path) with open(attr_save_path, "wb") as f: pickle.dump(user_attributes, f) @classmethod def load( cls, dir_path: str, adata_seq: Optional[AnnData] = None, adata_spatial: Optional[AnnData] = None, use_gpu: Optional[Union[str, int, bool]] = None, ): """ Instantiate a model from the saved output. Parameters ---------- adata_seq AnnData organized in the same way as data used to train model. It is not necessary to run :func:`~scvi.data.setup_anndata`, as AnnData is validated against the saved `scvi` setup dictionary. AnnData must be registered via :func:`~scvi.data.setup_anndata`. adata_spatial AnnData organized in the same way as data used to train model. If None, will check for and load anndata saved with the model. dir_path Path to saved outputs. use_gpu Load model on default GPU if available (if None or True), or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False). Returns ------- Model with loaded state dictionaries. 
        Examples
        --------
        >>> vae = GIMVI.load(save_path, adata_seq, adata_spatial)
        >>> vae.get_latent_representation()
        """
        model_path = os.path.join(dir_path, "model_params.pt")
        setup_dict_path = os.path.join(dir_path, "attr.pkl")
        seq_data_path = os.path.join(dir_path, "adata_seq.h5ad")
        spatial_data_path = os.path.join(dir_path, "adata_spatial.h5ad")
        seq_var_names_path = os.path.join(dir_path, "var_names_seq.csv")
        spatial_var_names_path = os.path.join(dir_path, "var_names_spatial.csv")

        if adata_seq is None and os.path.exists(seq_data_path):
            adata_seq = read(seq_data_path)
        elif adata_seq is None and not os.path.exists(seq_data_path):
            raise ValueError(
                "Save path contains no saved anndata and no adata was passed."
            )
        if adata_spatial is None and os.path.exists(spatial_data_path):
            adata_spatial = read(spatial_data_path)
        elif adata_spatial is None and not os.path.exists(spatial_data_path):
            raise ValueError(
                "Save path contains no saved anndata and no adata was passed."
            )
        adatas = [adata_seq, adata_spatial]

        seq_var_names = np.genfromtxt(seq_var_names_path, delimiter=",", dtype=str)
        spatial_var_names = np.genfromtxt(
            spatial_var_names_path, delimiter=",", dtype=str
        )
        var_names = [seq_var_names, spatial_var_names]

        for i, adata in enumerate(adatas):
            saved_var_names = var_names[i]
            user_var_names = adata.var_names.astype(str)
            if not np.array_equal(saved_var_names, user_var_names):
                logger.warning(
                    "var_names for adata passed in does not match var_names of "
                    "adata used to train the model. For valid results, the vars "
                    "need to be the same and in the same order as the adata used to train the model."
                )

        with open(setup_dict_path, "rb") as handle:
            attr_dict = pickle.load(handle)

        scvi_setup_dicts = attr_dict.pop("scvi_setup_dicts_")
        transfer_anndata_setup(scvi_setup_dicts["seq"], adata_seq)
        transfer_anndata_setup(scvi_setup_dicts["spatial"], adata_spatial)

        # get the parameters for the class init signature
        init_params = attr_dict.pop("init_params_")

        # new saving and loading, enable backwards compatibility
        if "non_kwargs" in init_params.keys():
            # grab all the parameters except for kwargs (is a dict)
            non_kwargs = init_params["non_kwargs"]
            kwargs = init_params["kwargs"]

            # expand out kwargs
            kwargs = {k: v for (i, j) in kwargs.items() for (k, v) in j.items()}
        else:
            # grab all the parameters except for kwargs (is a dict)
            non_kwargs = {
                k: v for k, v in init_params.items() if not isinstance(v, dict)
            }
            kwargs = {k: v for k, v in init_params.items() if isinstance(v, dict)}
            kwargs = {k: v for (i, j) in kwargs.items() for (k, v) in j.items()}
        model = cls(adata_seq, adata_spatial, **non_kwargs, **kwargs)
        for attr, val in attr_dict.items():
            setattr(model, attr, val)

        _, device = parse_use_gpu_arg(use_gpu)
        model.module.load_state_dict(torch.load(model_path, map_location=device))
        model.module.eval()
        model.to_device(device)
        return model
def train(
    self,
    max_epochs: int = 200,
    use_gpu: Optional[bool] = None,
    kappa: int = 5,
    train_size: float = 0.9,
    validation_size: Optional[float] = None,
    batch_size: int = 128,
    vae_task_kwargs: Optional[dict] = None,
    task_class: Optional[None] = None,
    **kwargs,
):
    """
    Train the model.

    Parameters
    ----------
    max_epochs
        Number of passes through the dataset. If `None`, defaults to
        `np.min([round((20000 / n_cells) * 400), 400])`
    use_gpu
        If `True`, use the GPU if available. Will override the use_gpu option when initializing model.
    kappa
        Scaling parameter for the discriminator loss.
    train_size
        Size of training set in the range [0.0, 1.0].
    validation_size
        Size of the test set. If `None`, defaults to 1 - `train_size`. If
        `train_size + validation_size < 1`, the remaining cells belong to a test set.
    batch_size
        Minibatch size to use during training.
    vae_task_kwargs
        Keyword args for model-specific Pytorch Lightning task. Keyword arguments passed to
        `train()` will overwrite values present in `vae_task_kwargs`, when appropriate.
    **kwargs
        Other keyword args for :class:`~scvi.lightning.Trainer`.
    """
    if use_gpu is None:
        use_gpu = self.use_gpu
    else:
        use_gpu = use_gpu and torch.cuda.is_available()
    gpus = 1 if use_gpu else None
    pin_memory = (
        True if (settings.dl_pin_memory_gpu_training and use_gpu) else False
    )

    self.trainer = Trainer(
        max_epochs=max_epochs,
        gpus=gpus,
        **kwargs,
    )
    self.train_indices_, self.test_indices_, self.validation_indices_ = [], [], []
    train_dls, test_dls, val_dls = [], [], []
    for i, ad in enumerate(self.adatas):
        train, val, test = self._train_test_val_split(
            ad,
            train_size=train_size,
            validation_size=validation_size,
            pin_memory=pin_memory,
            batch_size=batch_size,
        )
        train_dls.append(train)
        test_dls.append(test)
        val.mode = i
        val_dls.append(val)
        self.train_indices_.append(train.indices)
        self.test_indices_.append(test.indices)
        self.validation_indices_.append(val.indices)
    train_dl = TrainDL(train_dls)

    task_kwargs = vae_task_kwargs if isinstance(vae_task_kwargs, dict) else dict()
    self._pl_task = self._task_class(
        self.model,
        len(self.train_indices_),
        adversarial_classifier=True,
        scale_adversarial_loss=kappa,
        **task_kwargs,
    )

    if train_size == 1.0:
        # circumvent the empty data loader problem if all dataset used for training
        self.trainer.fit(self._pl_task, train_dl)
    else:
        # accepts list of val dataloaders
        self.trainer.fit(self._pl_task, train_dl, val_dls)
    try:
        self.history_ = self.trainer.logger.history
    except AttributeError:
        self.history_ = None
    self.model.eval()
    if use_gpu:
        self.model.cuda()
    self.is_trained_ = True