import numpy as np
import pytest

from scvi.data import synthetic_iid
from scvi.dataloaders import DataSplitter


def test_data_splitter():
    a = synthetic_iid()

    # test that leaving validation_size empty works
    ds = DataSplitter(a, train_size=0.4)
    ds.setup()
    # check the number of indices
    _, _, _ = ds.train_dataloader(), ds.val_dataloader(), ds.test_dataloader()
    n_train_idx = len(ds.train_idx)
    n_validation_idx = len(ds.val_idx) if ds.val_idx is not None else 0
    n_test_idx = len(ds.test_idx) if ds.test_idx is not None else 0

    assert n_train_idx + n_validation_idx + n_test_idx == a.n_obs
    assert np.isclose(n_train_idx / a.n_obs, 0.4)
    assert np.isclose(n_validation_idx / a.n_obs, 0.6)
    assert np.isclose(n_test_idx / a.n_obs, 0)

    # test an explicit validation_size
    ds = DataSplitter(a, train_size=0.4, validation_size=0.3)
    ds.setup()
    # check the number of indices
    _, _, _ = ds.train_dataloader(), ds.val_dataloader(), ds.test_dataloader()
    n_train_idx = len(ds.train_idx)
    n_validation_idx = len(ds.val_idx) if ds.val_idx is not None else 0
    n_test_idx = len(ds.test_idx) if ds.test_idx is not None else 0

    assert n_train_idx + n_validation_idx + n_test_idx == a.n_obs
    assert np.isclose(n_train_idx / a.n_obs, 0.4)
    assert np.isclose(n_validation_idx / a.n_obs, 0.3)
    assert np.isclose(n_test_idx / a.n_obs, 0.3)

    # test that 0 < train_size <= 1
    with pytest.raises(ValueError):
        ds = DataSplitter(a, train_size=2)
        ds.setup()
        ds.train_dataloader()
    with pytest.raises(ValueError):
        ds = DataSplitter(a, train_size=-2)
        ds.setup()
        ds.train_dataloader()

    # test that 0 <= validation_size < 1
    with pytest.raises(ValueError):
        ds = DataSplitter(a, train_size=0.1, validation_size=1)
        ds.setup()
        ds.val_dataloader()
    with pytest.raises(ValueError):
        ds = DataSplitter(a, train_size=0.1, validation_size=-1)
        ds.setup()
        ds.val_dataloader()

    # test that train_size + validation_size <= 1
    with pytest.raises(ValueError):
        ds = DataSplitter(a, train_size=1, validation_size=0.1)
        ds.setup()
        ds.train_dataloader()
        ds.val_dataloader()
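
# The assertions above pin down the contract DataSplitter is expected to enforce
# on split fractions. A minimal, self-contained sketch of that bounds checking
# (a hypothetical helper for illustration, not the actual scvi-tools
# implementation) could look like this:

from typing import Optional


def validate_data_split_sizes(
    train_size: float, validation_size: Optional[float] = None
):
    """Hypothetical sketch of the split validation exercised by the test above."""
    if not 0 < train_size <= 1:
        raise ValueError("train_size must be in (0, 1]")
    if validation_size is None:
        # leaving validation_size empty assigns all remaining cells to validation
        validation_size = 1.0 - train_size
    elif not 0 <= validation_size < 1:
        raise ValueError("validation_size must be in [0, 1)")
    if train_size + validation_size > 1:
        raise ValueError("train_size + validation_size must be <= 1")
    # whatever is left over becomes the test set
    return train_size, validation_size, 1.0 - train_size - validation_size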
def train(
    self,
    max_epochs: int = 200,
    use_gpu: Optional[Union[str, int, bool]] = None,
    kappa: int = 5,
    train_size: float = 0.9,
    validation_size: Optional[float] = None,
    batch_size: int = 128,
    plan_kwargs: Optional[dict] = None,
    **kwargs,
):
    """
    Train the model.

    Parameters
    ----------
    max_epochs
        Number of passes through the dataset. If `None`, defaults to
        `np.min([round((20000 / n_cells) * 400), 400])`.
    use_gpu
        Use default GPU if available (if `None` or `True`), or index of GPU to use (if `int`),
        or name of GPU (if `str`, e.g., `'cuda:0'`), or use CPU (if `False`).
    kappa
        Scaling parameter for the discriminator loss.
    train_size
        Size of the training set in the range [0.0, 1.0].
    validation_size
        Size of the validation set. If `None`, defaults to `1 - train_size`. If
        `train_size + validation_size < 1`, the remaining cells belong to a test set.
    batch_size
        Minibatch size to use during training.
    plan_kwargs
        Keyword args for the model-specific PyTorch Lightning task. Keyword arguments passed
        to `train()` will overwrite values present in `plan_kwargs`, when appropriate.
    **kwargs
        Other keyword args for :class:`~scvi.train.Trainer`.
    """
    gpus, device = parse_use_gpu_arg(use_gpu)
    self.trainer = Trainer(
        max_epochs=max_epochs,
        gpus=gpus,
        **kwargs,
    )
    self.train_indices_, self.test_indices_, self.validation_indices_ = [], [], []
    train_dls, test_dls, val_dls = [], [], []
    # build one split (and one set of dataloaders) per registered dataset
    for i, adm in enumerate(self.adata_managers.values()):
        ds = DataSplitter(
            adm,
            train_size=train_size,
            validation_size=validation_size,
            batch_size=batch_size,
            use_gpu=use_gpu,
        )
        ds.setup()
        train_dls.append(ds.train_dataloader())
        test_dls.append(ds.test_dataloader())
        val = ds.val_dataloader()
        val_dls.append(val)
        val.mode = i
        self.train_indices_.append(ds.train_idx)
        self.test_indices_.append(ds.test_idx)
        self.validation_indices_.append(ds.val_idx)

    train_dl = TrainDL(train_dls)

    plan_kwargs = plan_kwargs if isinstance(plan_kwargs, dict) else dict()
    self._training_plan = GIMVITrainingPlan(
        self.module,
        adversarial_classifier=True,
        scale_adversarial_loss=kappa,
        **plan_kwargs,
    )

    if train_size == 1.0:
        # circumvent the empty-validation-dataloader problem when the whole
        # dataset is used for training
        self.trainer.fit(self._training_plan, train_dl)
    else:
        # Trainer.fit accepts a list of validation dataloaders
        self.trainer.fit(self._training_plan, train_dl, val_dls)

    try:
        self.history_ = self.trainer.logger.history
    except AttributeError:
        self.history_ = None

    self.module.eval()
    self.to_device(device)
    self.is_trained_ = True
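
# Example call (a sketch; assumes `model` is an already-initialized GIMVI
# instance with registered AnnData objects, and that the underlying training
# plan accepts `lr` via `plan_kwargs`):
#
#     model.train(
#         max_epochs=200,
#         kappa=5,
#         train_size=0.9,
#         batch_size=128,
#         plan_kwargs={"lr": 1e-3},
#     )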