def test_semisupervised_data_splitter():
    """Check that SemiSupervisedDataSplitter produces the expected split sizes
    and preserves the labeled/unlabeled ratio in the training set."""
    a = synthetic_iid()

    # No cell carries the unlabeled category "asdf", so the split is purely
    # random: 90% train / 10% validation / 0% test.
    ds = SemiSupervisedDataSplitter(a, "asdf")
    train_dl, val_dl, test_dl = ds()
    n_train_idx = len(train_dl.indices)
    n_validation_idx = len(val_dl.indices)
    n_test_idx = len(test_dl.indices)
    assert n_train_idx + n_validation_idx + n_test_idx == a.n_obs
    assert np.isclose(n_train_idx / a.n_obs, 0.9)
    assert np.isclose(n_validation_idx / a.n_obs, 0.1)
    assert np.isclose(n_test_idx / a.n_obs, 0)

    # Test mix of labeled and unlabeled data: cells with "label_0" are
    # treated as unlabeled, so sizes are only approximately 90/10/0.
    unknown_label = "label_0"
    ds = SemiSupervisedDataSplitter(a, unknown_label)
    train_dl, val_dl, test_dl = ds()
    n_train_idx = len(train_dl.indices)
    n_validation_idx = len(val_dl.indices)
    n_test_idx = len(test_dl.indices)
    assert n_train_idx + n_validation_idx + n_test_idx == a.n_obs
    assert np.isclose(n_train_idx / a.n_obs, 0.9, rtol=0.05)
    assert np.isclose(n_validation_idx / a.n_obs, 0.1, rtol=0.05)
    assert np.isclose(n_test_idx / a.n_obs, 0, rtol=0.05)

    # Check that training indices have the proper mix of labeled and
    # unlabeled data.
    labelled_idx = np.where(a.obs["labels"] != unknown_label)[0]
    unlabelled_idx = np.where(a.obs["labels"] == unknown_label)[0]
    # Precompute sets: `i in <ndarray>` is an O(n) scan, making the
    # comprehensions below accidentally O(n^2); set membership is O(1).
    labelled_set = set(labelled_idx)
    unlabelled_set = set(unlabelled_idx)
    # labeled / unlabeled training idx
    labeled_train_idx = [i for i in train_dl.indices if i in labelled_set]
    unlabeled_train_idx = [i for i in train_dl.indices if i in unlabelled_set]
    # labeled vs unlabeled ratio in adata vs. in the train set
    adata_ratio = len(unlabelled_idx) / len(labelled_idx)
    train_ratio = len(unlabeled_train_idx) / len(labeled_train_idx)
    assert np.isclose(adata_ratio, train_ratio, atol=0.05)
def train(
    self,
    max_epochs: Optional[int] = None,
    n_samples_per_label: Optional[float] = None,
    check_val_every_n_epoch: Optional[int] = None,
    train_size: float = 0.9,
    validation_size: Optional[float] = None,
    batch_size: int = 128,
    use_gpu: Optional[Union[str, int, bool]] = None,
    plan_kwargs: Optional[dict] = None,
    **trainer_kwargs,
):
    """
    Train the model.

    Parameters
    ----------
    max_epochs
        Number of passes through the dataset for semisupervised training.
    n_samples_per_label
        Number of subsamples for each label class to sample per epoch.
        By default, there is no label subsampling.
    check_val_every_n_epoch
        Frequency with which metrics are computed on the data for validation
        set for both the unsupervised and semisupervised trainers. If you'd
        like a different frequency for the semisupervised trainer, set
        check_val_every_n_epoch in semisupervised_train_kwargs.
    train_size
        Size of training set in the range [0.0, 1.0].
    validation_size
        Size of the test set. If `None`, defaults to 1 - `train_size`. If
        `train_size + validation_size < 1`, the remaining cells belong to a
        test set.
    batch_size
        Minibatch size to use during training.
    use_gpu
        Use default GPU if available (if None or True), or index of GPU to
        use (if int), or name of GPU (if str, e.g., `'cuda:0'`), or use CPU
        (if False).
    plan_kwargs
        Keyword args for :class:`~scvi.train.SemiSupervisedTrainingPlan`.
        Keyword arguments passed to `train()` will overwrite values present
        in `plan_kwargs`, when appropriate.
    **trainer_kwargs
        Other keyword args for :class:`~scvi.train.Trainer`.
    """
    if max_epochs is None:
        n_cells = self.adata.n_obs
        # Heuristic: ~400 epochs for small datasets, scaled down as the
        # number of cells grows, capped at 400.
        max_epochs = np.min([round((20000 / n_cells) * 400), 400])
        if self.was_pretrained:
            # A pretrained model needs far fewer epochs: between 2 and 10.
            max_epochs = int(np.min([10, np.max([2, round(max_epochs / 3.0)])]))

    logger.info("Training for {} epochs.".format(max_epochs))

    plan_kwargs = {} if plan_kwargs is None else plan_kwargs

    # if we have labeled cells, we want to subsample labels each epoch
    sampler_callback = [SubSampleLabels()] if len(self._labeled_indices) != 0 else []

    data_splitter = SemiSupervisedDataSplitter(
        adata=self.adata,
        unlabeled_category=self.unlabeled_category_,
        train_size=train_size,
        validation_size=validation_size,
        n_samples_per_label=n_samples_per_label,
        batch_size=batch_size,
        use_gpu=use_gpu,
    )
    training_plan = SemiSupervisedTrainingPlan(self.module, **plan_kwargs)

    if "callbacks" in trainer_kwargs:
        # BUG FIX: lists have no `.concatenate` method, so the previous
        # `trainer_kwargs["callbacks"].concatenate(...)` raised
        # AttributeError whenever callers passed their own callbacks.
        # Build a new list rather than mutating the caller's list in place.
        trainer_kwargs["callbacks"] = list(trainer_kwargs["callbacks"]) + sampler_callback
    else:
        trainer_kwargs["callbacks"] = sampler_callback

    runner = TrainRunner(
        self,
        training_plan=training_plan,
        data_splitter=data_splitter,
        max_epochs=max_epochs,
        use_gpu=use_gpu,
        check_val_every_n_epoch=check_val_every_n_epoch,
        **trainer_kwargs,
    )
    return runner()
def test_semisupervised_data_splitter():
    """Check that SemiSupervisedDataSplitter (adata_manager API) produces the
    expected split sizes and preserves the labeled/unlabeled ratio in the
    training set."""
    a = synthetic_iid()
    adata_manager = generic_setup_adata_manager(a, batch_key="batch", labels_key="labels")

    # No cell carries the unlabeled category "asdf", so the split is purely
    # random: 90% train / 10% validation / 0% test.
    ds = SemiSupervisedDataSplitter(adata_manager, "asdf")
    ds.setup()
    _, _, _ = ds.train_dataloader(), ds.val_dataloader(), ds.test_dataloader()
    n_train_idx = len(ds.train_idx)
    n_validation_idx = len(ds.val_idx) if ds.val_idx is not None else 0
    n_test_idx = len(ds.test_idx) if ds.test_idx is not None else 0
    assert n_train_idx + n_validation_idx + n_test_idx == a.n_obs
    assert np.isclose(n_train_idx / a.n_obs, 0.9)
    assert np.isclose(n_validation_idx / a.n_obs, 0.1)
    assert np.isclose(n_test_idx / a.n_obs, 0)

    # Test mix of labeled and unlabeled data: cells with "label_0" are
    # treated as unlabeled, so sizes are only approximately 90/10/0.
    unknown_label = "label_0"
    ds = SemiSupervisedDataSplitter(adata_manager, unknown_label)
    ds.setup()
    _, _, _ = ds.train_dataloader(), ds.val_dataloader(), ds.test_dataloader()
    n_train_idx = len(ds.train_idx)
    n_validation_idx = len(ds.val_idx) if ds.val_idx is not None else 0
    n_test_idx = len(ds.test_idx) if ds.test_idx is not None else 0
    assert n_train_idx + n_validation_idx + n_test_idx == a.n_obs
    assert np.isclose(n_train_idx / a.n_obs, 0.9, rtol=0.05)
    assert np.isclose(n_validation_idx / a.n_obs, 0.1, rtol=0.05)
    assert np.isclose(n_test_idx / a.n_obs, 0, rtol=0.05)

    # Check that training indices have the proper mix of labeled and
    # unlabeled data.
    labelled_idx = np.where(a.obs["labels"] != unknown_label)[0]
    unlabelled_idx = np.where(a.obs["labels"] == unknown_label)[0]
    # Precompute sets: `i in <ndarray>` is an O(n) scan, making the
    # comprehensions below accidentally O(n^2); set membership is O(1).
    labelled_set = set(labelled_idx)
    unlabelled_set = set(unlabelled_idx)
    # labeled / unlabeled training idx
    labeled_train_idx = [i for i in ds.train_idx if i in labelled_set]
    unlabeled_train_idx = [i for i in ds.train_idx if i in unlabelled_set]
    # labeled vs unlabeled ratio in adata vs. in the train set
    adata_ratio = len(unlabelled_idx) / len(labelled_idx)
    train_ratio = len(unlabeled_train_idx) / len(labeled_train_idx)
    assert np.isclose(adata_ratio, train_ratio, atol=0.05)