import os

import numpy as np
import pyro
import torch

# Import locations assume the scvi-tools 0.x layout this test was written against.
from scvi.data import register_tensor_from_anndata, synthetic_iid
from scvi.dataloaders import AnnDataLoader
from scvi.module.base import PyroBaseModuleClass
from scvi.train import PyroTrainingPlan, Trainer


def test_pyro_bayesian_regression(save_path):
    use_gpu = int(torch.cuda.is_available())
    adata = synthetic_iid()
    # add index for each cell (provided to pyro plate for correct minibatching)
    adata.obs["_indices"] = np.arange(adata.n_obs).astype("int64")
    register_tensor_from_anndata(
        adata,
        registry_key="ind_x",
        adata_attr_name="obs",
        adata_key_name="_indices",
    )
    train_dl = AnnDataLoader(adata, shuffle=True, batch_size=128)
    pyro.clear_param_store()
    # BayesianRegressionModule is the PyroBaseModuleClass subclass defined
    # elsewhere in this test module.
    model = BayesianRegressionModule(in_features=adata.shape[1], out_features=1)
    plan = PyroTrainingPlan(model)
    plan.n_obs_training = len(train_dl.indices)
    trainer = Trainer(
        gpus=use_gpu,
        max_epochs=2,
    )
    trainer.fit(plan, train_dl)
    if use_gpu == 1:
        model.cuda()

    # test Predictive
    num_samples = 5
    predictive = model.create_predictive(num_samples=num_samples)
    for tensor_dict in train_dl:
        args, kwargs = model._get_fn_args_from_batch(tensor_dict)
        _ = {
            k: v.detach().cpu().numpy()
            for k, v in predictive(*args, **kwargs).items()
            if k != "obs"
        }

    # test save and load
    # quantiles are compared on CPU because CPU/GPU results differ slightly
    model.cpu()
    quants = model.guide.quantiles([0.5])
    sigma_median = quants["sigma"][0].detach().cpu().numpy()
    linear_median = quants["linear.weight"][0].detach().cpu().numpy()

    model_save_path = os.path.join(save_path, "model_params.pt")
    torch.save(model.state_dict(), model_save_path)

    pyro.clear_param_store()
    new_model = BayesianRegressionModule(in_features=adata.shape[1], out_features=1)
    # run the model one step to instantiate the autoguide params before loading
    try:
        new_model.load_state_dict(torch.load(model_save_path))
    except RuntimeError as err:
        if isinstance(new_model, PyroBaseModuleClass):
            plan = PyroTrainingPlan(new_model)
            plan.n_obs_training = len(train_dl.indices)
            trainer = Trainer(
                gpus=use_gpu,
                max_steps=1,
            )
            trainer.fit(plan, train_dl)
            new_model.load_state_dict(torch.load(model_save_path))
        else:
            raise err

    quants = new_model.guide.quantiles([0.5])
    sigma_median_new = quants["sigma"][0].detach().cpu().numpy()
    linear_median_new = quants["linear.weight"][0].detach().cpu().numpy()

    np.testing.assert_array_equal(sigma_median_new, sigma_median)
    np.testing.assert_array_equal(linear_median_new, linear_median)
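# Why the try/except around `load_state_dict` in the test above: Pyro
# autoguides create their parameters lazily on the first forward pass, so a
# freshly constructed module has an empty state dict and loading raises a
# RuntimeError about missing keys. Below is a minimal sketch of the same
# behaviour in plain Pyro, independent of scvi-tools; `toy_model`, the tensor
# shapes, and `demo_lazy_guide_init` are illustrative assumptions, not part of
# the test.
import pyro.distributions as dist
from pyro.infer.autoguide import AutoNormal


def toy_model(x, y=None):
    # Bayesian linear regression with a standard-normal prior on the weights.
    w = pyro.sample("w", dist.Normal(torch.zeros(x.shape[1]), 1.0).to_event(1))
    pyro.sample("obs", dist.Normal(x @ w, 1.0).to_event(1), obs=y)


def demo_lazy_guide_init():
    x, y = torch.randn(8, 3), torch.randn(8)
    guide = AutoNormal(toy_model)
    assert len(guide.state_dict()) == 0  # no parameters before the first call
    guide(x, y)  # first call traces the model and instantiates the parameters
    assert len(guide.state_dict()) > 0  # now safe to save and later restore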
def train(
    self,
    max_epochs: Optional[int] = None,
    use_gpu: Optional[Union[str, int, bool]] = None,
    train_size: float = 0.9,
    validation_size: Optional[float] = None,
    batch_size: Optional[int] = 128,
    early_stopping: bool = False,
    lr: Optional[float] = None,
    plan_kwargs: Optional[dict] = None,
    **trainer_kwargs,
):
    """
    Train the model.

    Parameters
    ----------
    max_epochs
        Number of passes through the dataset. If `None`, defaults to
        `np.min([round((20000 / n_cells) * 1000), 1000])`.
    use_gpu
        Use default GPU if available (if `None` or `True`), or index of GPU to use
        (if `int`), or name of GPU (if `str`, e.g., `'cuda:0'`), or use CPU (if `False`).
    train_size
        Size of training set in the range [0.0, 1.0].
    validation_size
        Size of the validation set. If `None`, defaults to `1 - train_size`. If
        `train_size + validation_size < 1`, the remaining cells belong to a test set.
    batch_size
        Minibatch size to use during training. If `None`, no minibatching occurs and
        all data is copied to device (e.g., GPU).
    early_stopping
        Perform early stopping. Additional arguments can be passed in `**kwargs`.
        See :class:`~scvi.train.Trainer` for further options.
    lr
        Optimiser learning rate (default optimiser is :class:`~pyro.optim.ClippedAdam`).
        Specifying an optimiser via `plan_kwargs` overrides this choice of `lr`.
    plan_kwargs
        Keyword args for :class:`~scvi.train.TrainingPlan`. Keyword arguments passed
        to `train()` will overwrite values present in `plan_kwargs`, when appropriate.
    **trainer_kwargs
        Other keyword args for :class:`~scvi.train.Trainer`.
    """
    if max_epochs is None:
        n_obs = self.adata.n_obs
        max_epochs = np.min([round((20000 / n_obs) * 1000), 1000])

    plan_kwargs = plan_kwargs if isinstance(plan_kwargs, dict) else dict()
    # route `lr` into the training plan unless an optimiser was given explicitly
    if lr is not None and "optim" not in plan_kwargs.keys():
        plan_kwargs.update({"optim_kwargs": {"lr": lr}})

    if batch_size is None:
        # use data splitter which moves data to GPU once
        data_splitter = DeviceBackedDataSplitter(
            self.adata,
            train_size=train_size,
            validation_size=validation_size,
            batch_size=batch_size,
            use_gpu=use_gpu,
        )
    else:
        data_splitter = DataSplitter(
            self.adata,
            train_size=train_size,
            validation_size=validation_size,
            batch_size=batch_size,
            use_gpu=use_gpu,
        )

    training_plan = PyroTrainingPlan(pyro_module=self.module, **plan_kwargs)

    es = "early_stopping"
    trainer_kwargs[es] = (
        early_stopping if es not in trainer_kwargs.keys() else trainer_kwargs[es]
    )

    data_splitter.setup()
    if "callbacks" not in trainer_kwargs.keys():
        trainer_kwargs["callbacks"] = []
    trainer_kwargs["callbacks"].append(
        PyroJitGuideWarmup(data_splitter.train_dataloader())
    )

    runner = TrainRunner(
        self,
        training_plan=training_plan,
        data_splitter=data_splitter,
        max_epochs=max_epochs,
        use_gpu=use_gpu,
        **trainer_kwargs,
    )
    return runner()
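# A hedged usage sketch for the `train` method above, assuming it is mixed into
# a high-level scvi-tools model class whose `self.module` is a
# PyroBaseModuleClass. `MyPyroModel` and its constructor are hypothetical
# stand-ins for illustration, not an API defined in this file.
def demo_train_usage(adata):
    model = MyPyroModel(adata)  # hypothetical model exposing `train` above
    # `lr` is routed into plan_kwargs["optim_kwargs"] unless an optimiser is
    # already supplied via plan_kwargs["optim"].
    model.train(max_epochs=5, use_gpu=False, lr=1e-3)
    # batch_size=None selects DeviceBackedDataSplitter, which copies the whole
    # dataset to the training device once and trains full-batch.
    model.train(max_epochs=5, batch_size=None)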