def train(
    self,
    max_epochs: Optional[int] = None,
    use_gpu: Optional[Union[str, int, bool]] = None,
    train_size: float = 0.9,
    validation_size: Optional[float] = None,
    batch_size: int = 128,
    plan_kwargs: Optional[dict] = None,
    **trainer_kwargs,
):
    """
    Train the model.

    Parameters
    ----------
    max_epochs
        Number of passes through the dataset. If `None`, defaults to
        `np.min([round((20000 / n_cells) * 400), 400])`.
    use_gpu
        Use default GPU if available (if None or True), or index of GPU to use (if int),
        or name of GPU (if str), or use CPU (if False).
    train_size
        Size of training set in the range [0.0, 1.0].
    validation_size
        Size of the test set. If `None`, defaults to 1 - `train_size`. If
        `train_size + validation_size < 1`, the remaining cells belong to a test set.
    batch_size
        Minibatch size to use during training.
    plan_kwargs
        Keyword args for :class:`~scvi.lightning.TrainingPlan`. Keyword arguments
        passed to `train()` will overwrite values present in `plan_kwargs`,
        when appropriate.
    **trainer_kwargs
        Other keyword args for :class:`~scvi.lightning.Trainer`.
    """
    # Heuristic epoch budget: smaller datasets get more passes, capped at 400.
    if max_epochs is None:
        n_cells = self.adata.n_obs
        max_epochs = np.min([round((20000 / n_cells) * 400), 400])

    plan_args = plan_kwargs if isinstance(plan_kwargs, dict) else dict()

    splitter = DataSplitter(
        self.adata,
        train_size=train_size,
        validation_size=validation_size,
        batch_size=batch_size,
        use_gpu=use_gpu,
    )
    plan = PyroTrainingPlan(self.module, **plan_args)
    run_training = TrainRunner(
        self,
        training_plan=plan,
        data_splitter=splitter,
        max_epochs=max_epochs,
        use_gpu=use_gpu,
        **trainer_kwargs,
    )
    return run_training()
def _scvi_loader(adata, train_size, batch_size, use_gpu=False):
    """
    SCVI splitter.

    Returns SCVI data loaders for the train and test sets.

    Parameters
    ----------
    adata
        AnnData object to split.
    train_size
        Fraction of cells assigned to the training set; the remaining
        `1 - train_size` fraction forms the second returned loader.
    batch_size
        Minibatch size used by both loaders.
    use_gpu
        Forwarded to :class:`DataSplitter` for device placement.

    Returns
    -------
    Tuple of (train loader, test loader).
    """
    data_splitter = DataSplitter(
        adata,
        train_size=train_size,
        validation_size=1.0 - train_size,
        batch_size=batch_size,
        use_gpu=use_gpu,
    )
    # NOTE(review): DataSplitter() returns (train, validation, test); the second
    # loader here is the validation split, used as the "test" set — confirm intent.
    train_dl, test_dl, _ = data_splitter()
    return train_dl, test_dl
def test_data_splitter():
    """Check DataSplitter index counts and argument validation (callable API)."""
    a = synthetic_iid()
    # test leaving validation_size empty works
    ds = DataSplitter(a, train_size=0.4)
    # check the number of indices
    train_dl, val_dl, test_dl = ds()
    n_train_idx = len(train_dl.indices)
    n_validation_idx = len(val_dl.indices)
    n_test_idx = len(test_dl.indices)

    assert n_train_idx + n_validation_idx + n_test_idx == a.n_obs
    assert np.isclose(n_train_idx / a.n_obs, 0.4)
    # validation defaults to 1 - train_size, leaving an empty test set
    assert np.isclose(n_validation_idx / a.n_obs, 0.6)
    assert np.isclose(n_test_idx / a.n_obs, 0)

    # test test size
    ds = DataSplitter(a, train_size=0.4, validation_size=0.3)
    # check the number of indices
    train_dl, val_dl, test_dl = ds()
    n_train_idx = len(train_dl.indices)
    n_validation_idx = len(val_dl.indices)
    n_test_idx = len(test_dl.indices)

    assert n_train_idx + n_validation_idx + n_test_idx == a.n_obs
    assert np.isclose(n_train_idx / a.n_obs, 0.4)
    assert np.isclose(n_validation_idx / a.n_obs, 0.3)
    # the remaining cells fall into the test split
    assert np.isclose(n_test_idx / a.n_obs, 0.3)

    # test that 0 < train_size <= 1
    with pytest.raises(ValueError):
        ds = DataSplitter(a, train_size=2)
        ds()
    with pytest.raises(ValueError):
        ds = DataSplitter(a, train_size=-2)
        ds()

    # test that 0 <= validation_size < 1
    with pytest.raises(ValueError):
        ds = DataSplitter(a, train_size=0.1, validation_size=1)
        ds()
    with pytest.raises(ValueError):
        ds = DataSplitter(a, train_size=0.1, validation_size=-1)
        ds()

    # test that train_size + validation_size <= 1
    with pytest.raises(ValueError):
        ds = DataSplitter(a, train_size=1, validation_size=0.1)
        ds()
def test_data_splitter():
    """Check DataSplitter index counts and argument validation (setup() API)."""
    a = synthetic_iid()
    # test leaving validation_size empty works
    ds = DataSplitter(a, train_size=0.4)
    ds.setup()
    # check the number of indices
    _, _, _ = ds.train_dataloader(), ds.val_dataloader(), ds.test_dataloader()
    n_train_idx = len(ds.train_idx)
    # val_idx / test_idx may be None when the corresponding split is empty
    n_validation_idx = len(ds.val_idx) if ds.val_idx is not None else 0
    n_test_idx = len(ds.test_idx) if ds.test_idx is not None else 0

    assert n_train_idx + n_validation_idx + n_test_idx == a.n_obs
    assert np.isclose(n_train_idx / a.n_obs, 0.4)
    # validation defaults to 1 - train_size, leaving an empty test set
    assert np.isclose(n_validation_idx / a.n_obs, 0.6)
    assert np.isclose(n_test_idx / a.n_obs, 0)

    # test test size
    ds = DataSplitter(a, train_size=0.4, validation_size=0.3)
    ds.setup()
    # check the number of indices
    _, _, _ = ds.train_dataloader(), ds.val_dataloader(), ds.test_dataloader()
    n_train_idx = len(ds.train_idx)
    n_validation_idx = len(ds.val_idx) if ds.val_idx is not None else 0
    n_test_idx = len(ds.test_idx) if ds.test_idx is not None else 0

    assert n_train_idx + n_validation_idx + n_test_idx == a.n_obs
    assert np.isclose(n_train_idx / a.n_obs, 0.4)
    assert np.isclose(n_validation_idx / a.n_obs, 0.3)
    # the remaining cells fall into the test split
    assert np.isclose(n_test_idx / a.n_obs, 0.3)

    # test that 0 < train_size <= 1
    with pytest.raises(ValueError):
        ds = DataSplitter(a, train_size=2)
        ds.setup()
        ds.train_dataloader()
    with pytest.raises(ValueError):
        ds = DataSplitter(a, train_size=-2)
        ds.setup()
        ds.train_dataloader()

    # test that 0 <= validation_size < 1
    with pytest.raises(ValueError):
        ds = DataSplitter(a, train_size=0.1, validation_size=1)
        ds.setup()
        ds.val_dataloader()
    with pytest.raises(ValueError):
        ds = DataSplitter(a, train_size=0.1, validation_size=-1)
        ds.setup()
        ds.val_dataloader()

    # test that train_size + validation_size <= 1
    with pytest.raises(ValueError):
        ds = DataSplitter(a, train_size=1, validation_size=0.1)
        ds.setup()
        ds.train_dataloader()
        ds.val_dataloader()
def train(
    self,
    max_epochs: Optional[int] = 400,
    lr: float = 1e-3,
    use_gpu: Optional[Union[str, int, bool]] = None,
    train_size: float = 1,
    validation_size: Optional[float] = None,
    batch_size: int = 128,
    plan_kwargs: Optional[dict] = None,
    early_stopping: bool = True,
    early_stopping_patience: int = 30,
    early_stopping_min_delta: float = 0.0,
    **kwargs,
):
    """
    Trains the model.

    Parameters
    ----------
    max_epochs
        Number of epochs to train for. If `None`, defaults to
        `np.min([round((20000 / n_cells) * 400), 400])`.
    lr
        Learning rate for optimization.
    use_gpu
        Use default GPU if available (if None or True), or index of GPU to use (if int),
        or name of GPU (if str, e.g., `'cuda:0'`), or use CPU (if False).
    train_size
        Size of training set in the range [0.0, 1.0].
    validation_size
        Size of the test set. If `None`, defaults to 1 - `train_size`. If
        `train_size + validation_size < 1`, the remaining cells belong to a test set.
    batch_size
        Minibatch size to use during training.
    plan_kwargs
        Keyword args for :class:`~scvi.train.ClassifierTrainingPlan`. Keyword arguments
        passed to `train()` will overwrite values present in `plan_kwargs`,
        when appropriate.
    early_stopping
        Adds callback for early stopping on validation_loss.
    early_stopping_patience
        Number of times early stopping metric can not improve over
        early_stopping_min_delta.
    early_stopping_min_delta
        Threshold for counting an epoch towards patience.
    **kwargs
        Other keyword args for :class:`~scvi.train.Trainer`.
    """
    update_dict = {
        "lr": lr,
    }
    # Copy before merging so the caller's dict is never mutated; values derived
    # from `train()` arguments take precedence over user-supplied `plan_kwargs`.
    plan_kwargs = dict(plan_kwargs) if isinstance(plan_kwargs, dict) else {}
    plan_kwargs.update(update_dict)

    if early_stopping:
        early_stopping_callback = [
            LoudEarlyStopping(
                monitor="validation_loss",
                min_delta=early_stopping_min_delta,
                patience=early_stopping_patience,
                mode="min",
            )
        ]
        if "callbacks" in kwargs:
            kwargs["callbacks"] += early_stopping_callback
        else:
            kwargs["callbacks"] = early_stopping_callback
        # Early stopping needs the validation metric computed every epoch.
        kwargs["check_val_every_n_epoch"] = 1

    # Heuristic epoch budget: smaller datasets get more passes, capped at 400.
    if max_epochs is None:
        n_cells = self.adata.n_obs
        max_epochs = np.min([round((20000 / n_cells) * 400), 400])

    data_splitter = DataSplitter(
        self.adata_manager,
        train_size=train_size,
        validation_size=validation_size,
        batch_size=batch_size,
        use_gpu=use_gpu,
    )
    training_plan = ClassifierTrainingPlan(self.module, **plan_kwargs)
    runner = TrainRunner(
        self,
        training_plan=training_plan,
        data_splitter=data_splitter,
        max_epochs=max_epochs,
        use_gpu=use_gpu,
        **kwargs,
    )
    return runner()
def train(
    self,
    max_epochs: int = 500,
    lr: float = 1e-4,
    use_gpu: Optional[Union[str, int, bool]] = None,
    train_size: float = 0.9,
    validation_size: Optional[float] = None,
    batch_size: int = 128,
    weight_decay: float = 1e-3,
    eps: float = 1e-08,
    early_stopping: bool = True,
    save_best: bool = True,
    check_val_every_n_epoch: Optional[int] = None,
    n_steps_kl_warmup: Optional[int] = None,
    n_epochs_kl_warmup: Optional[int] = 50,
    adversarial_mixing: bool = True,
    plan_kwargs: Optional[dict] = None,
    **kwargs,
):
    """
    Trains the model using amortized variational inference.

    Parameters
    ----------
    max_epochs
        Number of passes through the dataset.
    lr
        Learning rate for optimization.
    use_gpu
        Use default GPU if available (if None or True), or index of GPU to use (if int),
        or name of GPU (if str), or use CPU (if False).
    train_size
        Size of training set in the range [0.0, 1.0].
    validation_size
        Size of the test set. If `None`, defaults to 1 - `train_size`. If
        `train_size + validation_size < 1`, the remaining cells belong to a test set.
    batch_size
        Minibatch size to use during training.
    weight_decay
        Weight decay regularization term for optimization.
    eps
        Optimizer eps.
    early_stopping
        Whether to perform early stopping with respect to the validation set.
    save_best
        Save the best model state with respect to the validation loss, or use the
        final state in the training procedure.
    check_val_every_n_epoch
        Check val every n train epochs. By default, val is not checked, unless
        `early_stopping` is `True`. If so, val is checked every epoch.
    n_steps_kl_warmup
        Number of training steps (minibatches) to scale weight on KL divergences
        from 0 to 1. Only activated when `n_epochs_kl_warmup` is set to None.
        If `None`, defaults to `floor(0.75 * adata.n_obs)`.
    n_epochs_kl_warmup
        Number of epochs to scale weight on KL divergences from 0 to 1.
        Overrides `n_steps_kl_warmup` when both are not `None`.
    adversarial_mixing
        Whether to use adversarial training to penalize the model for unbalanced
        mixing of modalities.
    plan_kwargs
        Keyword args for :class:`~scvi.train.TrainingPlan`. Keyword arguments
        passed to `train()` will overwrite values present in `plan_kwargs`,
        when appropriate.
    **kwargs
        Other keyword args for :class:`~scvi.train.Trainer`.
    """
    update_dict = dict(
        lr=lr,
        adversarial_classifier=adversarial_mixing,
        weight_decay=weight_decay,
        eps=eps,
        n_epochs_kl_warmup=n_epochs_kl_warmup,
        n_steps_kl_warmup=n_steps_kl_warmup,
        check_val_every_n_epoch=check_val_every_n_epoch,
        early_stopping=early_stopping,
        early_stopping_monitor="reconstruction_loss_validation",
        early_stopping_patience=50,
        optimizer="AdamW",
        scale_adversarial_loss=1,
    )
    # Copy before merging so the caller's dict is never mutated; values derived
    # from `train()` arguments take precedence over user-supplied `plan_kwargs`.
    plan_kwargs = dict(plan_kwargs) if plan_kwargs is not None else {}
    plan_kwargs.update(update_dict)

    if save_best:
        if "callbacks" not in kwargs:
            kwargs["callbacks"] = []
        kwargs["callbacks"].append(
            SaveBestState(monitor="reconstruction_loss_validation")
        )

    data_splitter = DataSplitter(
        self.adata,
        train_size=train_size,
        validation_size=validation_size,
        batch_size=batch_size,
        use_gpu=use_gpu,
    )
    training_plan = AdversarialTrainingPlan(self.module, **plan_kwargs)
    runner = TrainRunner(
        self,
        training_plan=training_plan,
        data_splitter=data_splitter,
        max_epochs=max_epochs,
        use_gpu=use_gpu,
        early_stopping=early_stopping,
        **kwargs,
    )
    return runner()
def train(
    self,
    max_epochs: Optional[int] = 400,
    lr: float = 4e-3,
    use_gpu: Optional[Union[str, int, bool]] = None,
    train_size: float = 0.9,
    validation_size: Optional[float] = None,
    batch_size: int = 256,
    early_stopping: bool = True,
    check_val_every_n_epoch: Optional[int] = None,
    reduce_lr_on_plateau: bool = True,
    n_steps_kl_warmup: Union[int, None] = None,
    n_epochs_kl_warmup: Union[int, None] = None,
    adversarial_classifier: Optional[bool] = None,
    plan_kwargs: Optional[dict] = None,
    **kwargs,
):
    """
    Trains the model using amortized variational inference.

    Parameters
    ----------
    max_epochs
        Number of passes through the dataset. If `None`, defaults to
        `np.min([round((20000 / n_cells) * 400), 400])`.
    lr
        Learning rate for optimization.
    use_gpu
        Use default GPU if available (if None or True), or index of GPU to use (if int),
        or name of GPU (if str), or use CPU (if False).
    train_size
        Size of training set in the range [0.0, 1.0].
    validation_size
        Size of the test set. If `None`, defaults to 1 - `train_size`. If
        `train_size + validation_size < 1`, the remaining cells belong to a test set.
    batch_size
        Minibatch size to use during training.
    early_stopping
        Whether to perform early stopping with respect to the validation set.
    check_val_every_n_epoch
        Check val every n train epochs. By default, val is not checked, unless
        `early_stopping` is `True` or `reduce_lr_on_plateau` is `True`. If either of
        the latter conditions are met, val is checked every epoch.
    reduce_lr_on_plateau
        Reduce learning rate on plateau of validation metric (default is ELBO).
    n_steps_kl_warmup
        Number of training steps (minibatches) to scale weight on KL divergences
        from 0 to 1. Only activated when `n_epochs_kl_warmup` is set to None.
        If `None`, defaults to `floor(0.75 * adata.n_obs)`.
    n_epochs_kl_warmup
        Number of epochs to scale weight on KL divergences from 0 to 1.
        Overrides `n_steps_kl_warmup` when both are not `None`.
    adversarial_classifier
        Whether to use adversarial classifier in the latent space. This helps mixing
        when there are missing proteins in any of the batches. Defaults to `True` if
        missing proteins are detected.
    plan_kwargs
        Keyword args for :class:`~scvi.train.AdversarialTrainingPlan`. Keyword
        arguments passed to `train()` will overwrite values present in
        `plan_kwargs`, when appropriate.
    **kwargs
        Other keyword args for :class:`~scvi.train.Trainer`.
    """
    if adversarial_classifier is None:
        # A registered batch mask means some batches miss proteins, i.e. protein
        # imputation is needed; only then default to adversarial mixing.
        adversarial_classifier = "totalvi_batch_mask" in self.scvi_setup_dict_
    n_steps_kl_warmup = (
        n_steps_kl_warmup
        if n_steps_kl_warmup is not None
        else int(0.75 * self.adata.n_obs)
    )
    if reduce_lr_on_plateau:
        # LR scheduling needs the validation metric computed every epoch.
        check_val_every_n_epoch = 1

    update_dict = {
        "lr": lr,
        "adversarial_classifier": adversarial_classifier,
        "reduce_lr_on_plateau": reduce_lr_on_plateau,
        "n_epochs_kl_warmup": n_epochs_kl_warmup,
        "n_steps_kl_warmup": n_steps_kl_warmup,
        "check_val_every_n_epoch": check_val_every_n_epoch,
    }
    # Copy before merging so the caller's dict is never mutated; values derived
    # from `train()` arguments take precedence over user-supplied `plan_kwargs`.
    plan_kwargs = dict(plan_kwargs) if isinstance(plan_kwargs, dict) else {}
    plan_kwargs.update(update_dict)

    # Heuristic epoch budget: smaller datasets get more passes, capped at 400.
    if max_epochs is None:
        n_cells = self.adata.n_obs
        max_epochs = np.min([round((20000 / n_cells) * 400), 400])

    data_splitter = DataSplitter(
        self.adata,
        train_size=train_size,
        validation_size=validation_size,
        batch_size=batch_size,
        use_gpu=use_gpu,
    )
    # NOTE(review): assumes `train_idx` is available on the splitter before
    # `setup()` is called — confirm against the DataSplitter implementation.
    training_plan = AdversarialTrainingPlan(
        self.module, len(data_splitter.train_idx), **plan_kwargs
    )
    runner = TrainRunner(
        self,
        training_plan=training_plan,
        data_splitter=data_splitter,
        max_epochs=max_epochs,
        use_gpu=use_gpu,
        early_stopping=early_stopping,
        **kwargs,
    )
    return runner()
def train(
    self,
    max_epochs: int = 200,
    use_gpu: Optional[Union[str, int, bool]] = None,
    kappa: int = 5,
    train_size: float = 0.9,
    validation_size: Optional[float] = None,
    batch_size: int = 128,
    plan_kwargs: Optional[dict] = None,
    **kwargs,
):
    """
    Train the model.

    Parameters
    ----------
    max_epochs
        Number of passes through the dataset. If `None`, defaults to
        `np.min([round((20000 / n_cells) * 400), 400])`
    use_gpu
        Use default GPU if available (if None or True), or index of GPU to use (if int),
        or name of GPU (if str, e.g., `'cuda:0'`), or use CPU (if False).
    kappa
        Scaling parameter for the discriminator loss.
    train_size
        Size of training set in the range [0.0, 1.0].
    validation_size
        Size of the test set. If `None`, defaults to 1 - `train_size`. If
        `train_size + validation_size < 1`, the remaining cells belong to a test set.
    batch_size
        Minibatch size to use during training.
    plan_kwargs
        Keyword args for model-specific Pytorch Lightning task. Keyword arguments passed to
        `train()` will overwrite values present in `plan_kwargs`, when appropriate.
    **kwargs
        Other keyword args for :class:`~scvi.train.Trainer`.
    """
    gpus, device = parse_use_gpu_arg(use_gpu)

    self.trainer = Trainer(
        max_epochs=max_epochs,
        gpus=gpus,
        **kwargs,
    )
    self.train_indices_, self.test_indices_, self.validation_indices_ = [], [], []
    train_dls, test_dls, val_dls = [], [], []
    # Build one data splitter (and one loader triple) per registered AnnData manager.
    for i, adm in enumerate(self.adata_managers.values()):
        ds = DataSplitter(
            adm,
            train_size=train_size,
            validation_size=validation_size,
            batch_size=batch_size,
            use_gpu=use_gpu,
        )
        ds.setup()
        train_dls.append(ds.train_dataloader())
        test_dls.append(ds.test_dataloader())
        val = ds.val_dataloader()
        val_dls.append(val)
        # Tag each validation loader with the index of its dataset; presumably
        # consumed downstream to tell the loaders apart — confirm with the plan.
        val.mode = i
        self.train_indices_.append(ds.train_idx)
        self.test_indices_.append(ds.test_idx)
        self.validation_indices_.append(ds.val_idx)

    # Combine the per-dataset training loaders into a single loader.
    train_dl = TrainDL(train_dls)

    plan_kwargs = plan_kwargs if isinstance(plan_kwargs, dict) else dict()
    self._training_plan = GIMVITrainingPlan(
        self.module,
        adversarial_classifier=True,
        scale_adversarial_loss=kappa,
        **plan_kwargs,
    )

    if train_size == 1.0:
        # circumvent the empty data loader problem if all dataset used for training
        self.trainer.fit(self._training_plan, train_dl)
    else:
        # accepts list of val dataloaders
        self.trainer.fit(self._training_plan, train_dl, val_dls)
    # Not every logger exposes `history`; fall back to None when absent.
    try:
        self.history_ = self.trainer.logger.history
    except AttributeError:
        self.history_ = None
    self.module.eval()
    self.to_device(device)
    self.is_trained_ = True
def train(
    self,
    max_epochs: Optional[int] = None,
    use_gpu: Optional[Union[str, int, bool]] = None,
    train_size: float = 0.9,
    validation_size: Optional[float] = None,
    batch_size: int = 128,
    early_stopping: bool = False,
    lr: Optional[float] = None,
    plan_kwargs: Optional[dict] = None,
    **trainer_kwargs,
):
    """
    Train the model.

    Parameters
    ----------
    max_epochs
        Number of passes through the dataset. If `None`, defaults to
        `np.min([round((20000 / n_cells) * 400), 400])`
    use_gpu
        Use default GPU if available (if None or True), or index of GPU to use (if int),
        or name of GPU (if str, e.g., `'cuda:0'`), or use CPU (if False).
    train_size
        Size of training set in the range [0.0, 1.0].
    validation_size
        Size of the test set. If `None`, defaults to 1 - `train_size`. If
        `train_size + validation_size < 1`, the remaining cells belong to a test set.
    batch_size
        Minibatch size to use during training. If `None`, no minibatching occurs
        and all data is copied to device (e.g., GPU).
    early_stopping
        Perform early stopping. Additional arguments can be passed in `**kwargs`.
        See :class:`~scvi.train.Trainer` for further options.
    lr
        Optimiser learning rate (default optimiser is :class:`~pyro.optim.ClippedAdam`).
        Specifying optimiser via plan_kwargs overrides this choice of lr.
    plan_kwargs
        Keyword args for :class:`~scvi.train.TrainingPlan`. Keyword arguments passed to
        `train()` will overwrite values present in `plan_kwargs`, when appropriate.
    **trainer_kwargs
        Other keyword args for :class:`~scvi.train.Trainer`.
    """
    # Heuristic epoch budget: smaller datasets get more passes, capped at 1000.
    if max_epochs is None:
        n_obs = self.adata.n_obs
        max_epochs = np.min([round((20000 / n_obs) * 1000), 1000])

    # Copy so the caller's dict is never mutated by the merges below.
    plan_kwargs = dict(plan_kwargs) if isinstance(plan_kwargs, dict) else {}
    if lr is not None and "optim" not in plan_kwargs:
        # Merge `lr` into any user-supplied optim_kwargs instead of replacing the
        # whole dict, so other optimiser options survive; explicit `lr` wins.
        optim_kwargs = dict(plan_kwargs.get("optim_kwargs", {}))
        optim_kwargs["lr"] = lr
        plan_kwargs["optim_kwargs"] = optim_kwargs

    if batch_size is None:
        # use data splitter which moves data to GPU once
        data_splitter = DeviceBackedDataSplitter(
            self.adata,
            train_size=train_size,
            validation_size=validation_size,
            batch_size=batch_size,
            use_gpu=use_gpu,
        )
    else:
        data_splitter = DataSplitter(
            self.adata,
            train_size=train_size,
            validation_size=validation_size,
            batch_size=batch_size,
            use_gpu=use_gpu,
        )
    training_plan = PyroTrainingPlan(pyro_module=self.module, **plan_kwargs)

    # An explicit `early_stopping` in trainer_kwargs takes precedence over the
    # `early_stopping` argument.
    trainer_kwargs.setdefault("early_stopping", early_stopping)
    # Warm up the Pyro guide before JIT compilation.
    trainer_kwargs.setdefault("callbacks", []).append(PyroJitGuideWarmup())

    runner = TrainRunner(
        self,
        training_plan=training_plan,
        data_splitter=data_splitter,
        max_epochs=max_epochs,
        use_gpu=use_gpu,
        **trainer_kwargs,
    )
    return runner()