def test_pyro_bayesian_regression_jit():
    """Train ``BayesianRegressionModule`` with a JIT-traced ELBO, then run ``Predictive``.

    Checks that the first autoguide parameter has one entry per feature plus
    one for ``sigma`` and one for the bias, and that posterior predictive
    sampling runs over every minibatch of the training loader.
    """
    use_gpu = int(torch.cuda.is_available())
    adata = synthetic_iid()
    # One loader shared by the JIT warmup callback, training and prediction.
    train_dl = AnnDataLoader(adata, shuffle=True, batch_size=128)
    pyro.clear_param_store()
    model = BayesianRegressionModule(adata.shape[1], 1)
    plan = PyroTrainingPlan(model, loss_fn=pyro.infer.JitTrace_ELBO())
    trainer = Trainer(
        gpus=use_gpu, max_epochs=2, callbacks=[PyroJitGuideWarmup(train_dl)]
    )
    trainer.fit(plan, train_dl)

    # 100 features, 1 for sigma, 1 for bias
    assert list(model.guide.parameters())[0].shape[0] == 102

    if use_gpu == 1:
        model.cuda()

    # test Predictive
    num_samples = 5
    predictive = model.create_predictive(num_samples=num_samples)
    for tensor_dict in train_dl:
        args, kwargs = model._get_fn_args_from_batch(tensor_dict)
        _ = {
            k: v.detach().cpu().numpy()
            for k, v in predictive(*args, **kwargs).items()
            if k != "obs"
        }
def test_pyro_bayesian_regression(save_path):
    """Train ``BayesianRegressionModule``, sample ``Predictive``, then round-trip its state.

    The trained module's ``state_dict`` is saved under ``save_path`` and loaded
    into a fresh module. The guide's median ``sigma`` and ``linear.weight`` of
    both modules, each computed on CPU, must be exactly equal.
    """
    use_gpu = int(torch.cuda.is_available())
    adata = synthetic_iid()
    train_dl = AnnDataLoader(adata, shuffle=True, batch_size=128)
    pyro.clear_param_store()
    model = BayesianRegressionModule(adata.shape[1], 1)
    plan = PyroTrainingPlan(model)
    plan.n_obs_training = len(train_dl.indices)
    trainer = Trainer(
        gpus=use_gpu,
        max_epochs=2,
    )
    trainer.fit(plan, train_dl)
    if use_gpu == 1:
        model.cuda()

    # test Predictive
    num_samples = 5
    predictive = model.create_predictive(num_samples=num_samples)
    for tensor_dict in train_dl:
        args, kwargs = model._get_fn_args_from_batch(tensor_dict)
        _ = {
            k: v.detach().cpu().numpy()
            for k, v in predictive(*args, **kwargs).items()
            if k != "obs"
        }

    # test save and load
    # cpu/gpu has minor difference
    model.cpu()
    quants = model.guide.quantiles([0.5])
    sigma_median = quants["sigma"][0].detach().cpu().numpy()
    linear_median = quants["linear.weight"][0].detach().cpu().numpy()

    model_save_path = os.path.join(save_path, "model_params.pt")
    torch.save(model.state_dict(), model_save_path)

    pyro.clear_param_store()
    new_model = BayesianRegressionModule(adata.shape[1], 1)
    # run model one step to get autoguide params
    try:
        new_model.load_state_dict(torch.load(model_save_path))
    except RuntimeError:
        if not isinstance(new_model, PyroBaseModuleClass):
            raise
        plan = PyroTrainingPlan(new_model)
        plan.n_obs_training = len(train_dl.indices)
        trainer = Trainer(
            gpus=use_gpu,
            max_steps=1,
        )
        trainer.fit(plan, train_dl)
        new_model.load_state_dict(torch.load(model_save_path))

    # Compare on CPU, like the reference medians above, so that a possible GPU
    # fallback fit cannot introduce device-dependent numeric differences.
    new_model.cpu()
    quants = new_model.guide.quantiles([0.5])
    sigma_median_new = quants["sigma"][0].detach().cpu().numpy()
    linear_median_new = quants["linear.weight"][0].detach().cpu().numpy()

    np.testing.assert_array_equal(sigma_median_new, sigma_median)
    np.testing.assert_array_equal(linear_median_new, linear_median)
def test_pyro_bayesian_regression_jit():
    """JIT-traced training of ``BayesianRegressionModule`` with per-cell indices.

    An ``ind_x`` tensor holding each cell's row index is registered so the
    module's plate can subsample correctly. The model is trained for two epochs
    with ``JitTrace_ELBO``. The test then checks the shapes of the autoguide
    location parameters and draws ``Predictive`` samples for every minibatch.
    """
    gpu_flag = int(torch.cuda.is_available())

    adata = synthetic_iid()
    # Row index of every cell, provided to the pyro plate for correct minibatching.
    adata.obs["_indices"] = np.arange(adata.n_obs).astype("int64")
    register_tensor_from_anndata(
        adata,
        registry_key="ind_x",
        adata_attr_name="obs",
        adata_key_name="_indices",
    )
    loader = AnnDataLoader(adata, shuffle=True, batch_size=128)

    pyro.clear_param_store()
    module = BayesianRegressionModule(in_features=adata.shape[1], out_features=1)
    plan = PyroTrainingPlan(module, loss_fn=pyro.infer.JitTrace_ELBO())
    plan.n_obs_training = len(loader.indices)
    warmup = PyroJitGuideWarmup(loader)
    Trainer(gpus=gpu_flag, max_epochs=2, callbacks=[warmup]).fit(plan, loader)

    guide_state = module.guide.state_dict()
    # 100 features
    assert tuple(guide_state["locs.linear.weight_unconstrained"].shape) == (1, 100)
    # 1 bias
    assert tuple(guide_state["locs.linear.bias_unconstrained"].shape) == (1,)

    if gpu_flag == 1:
        module.cuda()

    # Posterior predictive draws for every minibatch; the "obs" site is skipped.
    predictive = module.create_predictive(num_samples=5)
    for batch in loader:
        fn_args, fn_kwargs = module._get_fn_args_from_batch(batch)
        _ = {
            site: draws.detach().cpu().numpy()
            for site, draws in predictive(*fn_args, **fn_kwargs).items()
            if site != "obs"
        }
def test_pyro_bayesian_regression_jit():
    """JIT-traced training of ``BayesianRegressionModule`` through an AnnDataManager.

    Uses ``_create_indices_adata_manager`` to build the loader's data manager,
    trains for two epochs with ``JitTrace_ELBO``, checks the shapes of the
    autoguide location parameters and draws ``Predictive`` samples for every
    minibatch.
    """
    gpu_flag = int(torch.cuda.is_available())

    adata = synthetic_iid()
    manager = _create_indices_adata_manager(adata)
    loader = AnnDataLoader(manager, shuffle=True, batch_size=128)

    pyro.clear_param_store()
    module = BayesianRegressionModule(in_features=adata.shape[1], out_features=1)
    plan = PyroTrainingPlan(module, loss_fn=pyro.infer.JitTrace_ELBO())
    plan.n_obs_training = len(loader.indices)
    warmup = PyroJitGuideWarmup(loader)
    Trainer(gpus=gpu_flag, max_epochs=2, callbacks=[warmup]).fit(plan, loader)

    guide_state = module.guide.state_dict()
    # 100 features
    assert tuple(guide_state["locs.linear.weight_unconstrained"].shape) == (1, 100)
    # 1 bias
    assert tuple(guide_state["locs.linear.bias_unconstrained"].shape) == (1,)

    if gpu_flag == 1:
        module.cuda()

    # Posterior predictive draws for every minibatch; the "obs" site is skipped.
    predictive = module.create_predictive(num_samples=5)
    for batch in loader:
        fn_args, fn_kwargs = module._get_fn_args_from_batch(batch)
        _ = {
            site: draws.detach().cpu().numpy()
            for site, draws in predictive(*fn_args, **fn_kwargs).items()
            if site != "obs"
        }
class TrainRunner:
    """
    TrainRunner calls Trainer.fit() and handles pre and post training procedures.

    Parameters
    ----------
    model
        model to train
    training_plan
        initialized TrainingPlan
    data_splitter
        initialized :class:`~scvi.dataloaders.SemiSupervisedDataSplitter` or
        :class:`~scvi.dataloaders.DataSplitter`
    max_epochs
        max_epochs to train for
    use_gpu
        Use default GPU if available (if None or True), or index of GPU to use (if int),
        or name of GPU (if str, e.g., `'cuda:0'`), or use CPU (if False).
    trainer_kwargs
        Extra kwargs for :class:`~scvi.train.Trainer`

    Examples
    --------
    >>> # Following code should be within a subclass of BaseModelClass
    >>> data_splitter = DataSplitter(self.adata)
    >>> training_plan = TrainingPlan(self.module, len(data_splitter.train_idx))
    >>> runner = TrainRunner(
    >>>     self,
    >>>     training_plan=training_plan,
    >>>     data_splitter=data_splitter,
    >>>     max_epochs=max_epochs)
    >>> runner()
    """

    def __init__(
        self,
        model: BaseModelClass,
        training_plan: pl.LightningModule,
        data_splitter: Union[SemiSupervisedDataSplitter, DataSplitter],
        max_epochs: int,
        use_gpu: Optional[Union[str, int, bool]] = None,
        **trainer_kwargs,
    ):
        self.training_plan = training_plan
        self.data_splitter = data_splitter
        self.model = model
        gpus, device = parse_use_gpu_arg(use_gpu)
        self.gpus = gpus
        self.device = device
        self.trainer = Trainer(max_epochs=max_epochs, gpus=gpus, **trainer_kwargs)

    def __call__(self):
        """Fit the plan, then copy history, split indices and trained state onto the model."""
        # Tell the plan how many training observations there are, when the splitter knows.
        if hasattr(self.data_splitter, "n_train"):
            self.training_plan.n_obs_training = self.data_splitter.n_train

        self.trainer.fit(self.training_plan, self.data_splitter)
        self._update_history()

        # data splitter only gets these attrs after fit
        self.model.train_indices = self.data_splitter.train_idx
        self.model.test_indices = self.data_splitter.test_idx
        self.model.validation_indices = self.data_splitter.val_idx

        self.model.module.eval()
        self.model.is_trained_ = True
        self.model.to_device(self.device)
        self.model.trainer = self.trainer

    def _update_history(self):
        """Set or extend ``model.history_`` from the trainer logger's ``history``.

        First training session: ``model.history_`` becomes the logger's ``history``,
        or ``None`` if the logger has no such attribute.

        Further training: for every key of the existing ``model.history_`` dict
        that the new logger also reports, the new rows are re-indexed to continue
        after the old ones and concatenated. If either side has no dict-style
        history, a warning is emitted and the history is left untouched.
        """
        # model is being further trained
        # this was set to true during first training session
        if self.model.is_trained_ is True:
            # if not using the default logger (e.g., tensorboard)
            if not isinstance(self.model.history_, dict):
                warnings.warn(
                    "Training history cannot be updated. Logger can be accessed from model.trainer.logger"
                )
                return
            # The logger used for this session might not expose a history either.
            new_history = getattr(self.trainer.logger, "history", None)
            if new_history is None:
                warnings.warn(
                    "Training history cannot be updated. Logger can be accessed from model.trainer.logger"
                )
                return
            for key, val in self.model.history_.items():
                # e.g., no validation loss due to training params
                if key not in new_history:
                    continue
                prev_len = len(val)
                new_len = len(new_history[key])
                # Continue the index after the rows from previous sessions.
                index = np.arange(prev_len, prev_len + new_len)
                new_history[key].index = index
                self.model.history_[key] = pd.concat(
                    [
                        val,
                        new_history[key],
                    ]
                )
                self.model.history_[key].index.name = val.index.name
        else:
            # set history_ attribute if it exists
            # other pytorch lightning loggers might not have history attr
            try:
                self.model.history_ = self.trainer.logger.history
            except AttributeError:
                # Record the absence on the model (not the runner) so callers see it.
                self.model.history_ = None
class TrainRunner:
    """
    TrainRunner calls Trainer.fit() and handles pre and post training procedures.

    Parameters
    ----------
    model
        model to train
    training_plan
        initialized TrainingPlan
    data_splitter
        initialized :class:`~scvi.dataloaders.SemiSupervisedDataSplitter` or
        :class:`~scvi.dataloaders.DataSplitter`
    max_epochs
        max_epochs to train for
    use_gpu
        Use default GPU if available (if None or True), or index of GPU to use (if int),
        or name of GPU (if str), or use CPU (if False).
    trainer_kwargs
        Extra kwargs for :class:`~scvi.train.Trainer`

    Examples
    --------
    >>> # Following code should be within a subclass of BaseModelClass
    >>> data_splitter = DataSplitter(self.adata)
    >>> training_plan = TrainingPlan(self.module, len(data_splitter.train_idx))
    >>> runner = TrainRunner(
    >>>     self,
    >>>     training_plan=training_plan,
    >>>     data_splitter=data_splitter,
    >>>     max_epochs=max_epochs)
    >>> runner()
    """

    def __init__(
        self,
        model: BaseModelClass,
        training_plan: pl.LightningModule,
        data_splitter: Union[SemiSupervisedDataSplitter, DataSplitter],
        max_epochs: int,
        use_gpu: Optional[Union[str, int, bool]] = None,
        **trainer_kwargs,
    ):
        self.training_plan = training_plan
        self.data_splitter = data_splitter
        self.model = model
        gpus, device = parse_use_gpu_arg(use_gpu)
        self.gpus = gpus
        self.device = device
        self.trainer = Trainer(max_epochs=max_epochs, gpus=gpus, **trainer_kwargs)

    def __call__(self):
        """Split the data, fit the plan and record history and trained state on the model."""
        train_dl, val_dl, test_dl = self.data_splitter()
        self.model.train_indices = train_dl.indices
        self.model.test_indices = test_dl.indices
        self.model.validation_indices = val_dl.indices

        if len(val_dl.indices) == 0:
            # circumvent the empty data loader problem if all dataset used for training
            self.trainer.fit(self.training_plan, train_dl)
        else:
            self.trainer.fit(self.training_plan, train_dl, val_dl)
        try:
            self.model.history_ = self.trainer.logger.history
        except AttributeError:
            # Loggers without a `history` attribute: record the absence on the model
            # (not on the runner) so callers of `model.history_` see it.
            self.model.history_ = None

        self.model.module.eval()
        self.model.is_trained_ = True
        self.model.to_device(self.device)
        self.model.trainer = self.trainer
class GIMVI(VAEMixin, BaseModelClass):
    """
    Joint VAE for imputing missing genes in spatial data [Lopez19]_.

    Parameters
    ----------
    adata_seq
        AnnData object that has been registered via :meth:`~scvi.external.GIMVI.setup_anndata`
        and contains RNA-seq data.
    adata_spatial
        AnnData object that has been registered via :meth:`~scvi.external.GIMVI.setup_anndata`
        and contains spatial data. Its genes must be a subset of the genes of
        ``adata_seq``, and it must be a different object from ``adata_seq``.
        Its registered batch column is shifted in place by the number of
        batches in ``adata_seq``, so the two datasets use disjoint batch indices.
    generative_distributions
        List of generative distribution for adata_seq data and adata_spatial data.
        If `None`, defaults to ``["zinb", "nb"]``.
    model_library_size
        List of bool of whether to model library size for adata_seq and adata_spatial.
        If `None`, defaults to ``[True, False]``.
    n_latent
        Dimensionality of the latent space.
    **model_kwargs
        Keyword args for :class:`~scvi.external.gimvi.JVAE`

    Examples
    --------
    >>> adata_seq = anndata.read_h5ad(path_to_anndata_seq)
    >>> adata_spatial = anndata.read_h5ad(path_to_anndata_spatial)
    >>> scvi.external.GIMVI.setup_anndata(adata_seq)
    >>> scvi.external.GIMVI.setup_anndata(adata_spatial)
    >>> vae = scvi.external.GIMVI(adata_seq, adata_spatial)
    >>> vae.train(max_epochs=400)

    Notes
    -----
    See further usage examples in the following tutorials:

    1. :doc:`/user_guide/notebooks/gimvi_tutorial`
    """

    def __init__(
        self,
        adata_seq: AnnData,
        adata_spatial: AnnData,
        generative_distributions: Optional[List] = None,
        model_library_size: Optional[List] = None,
        n_latent: int = 10,
        **model_kwargs,
    ):
        super().__init__()
        # Build defaults per call so instances never share one mutable list.
        if generative_distributions is None:
            generative_distributions = ["zinb", "nb"]
        if model_library_size is None:
            model_library_size = [True, False]

        if adata_seq is adata_spatial:
            raise ValueError(
                "`adata_seq` and `adata_spatial` cannot point to the same object. "
                "If you would really like to do this, make a copy of the object and pass it in as `adata_spatial`."
            )
        self.adatas = [adata_seq, adata_spatial]
        self.adata_managers = {
            "seq": self._get_most_recent_anndata_manager(adata_seq, required=True),
            "spatial": self._get_most_recent_anndata_manager(
                adata_spatial, required=True
            ),
        }
        self.registries_ = []
        for adm in self.adata_managers.values():
            self._register_manager_for_instance(adm)
            self.registries_.append(adm.registry)

        seq_var_names = adata_seq.var_names
        spatial_var_names = adata_spatial.var_names

        if not set(spatial_var_names) <= set(seq_var_names):
            raise ValueError("spatial genes needs to be subset of seq genes")

        # Position of each spatial gene along the seq gene axis.
        spatial_gene_loc = [
            np.argwhere(seq_var_names == g)[0] for g in spatial_var_names
        ]
        spatial_gene_loc = np.concatenate(spatial_gene_loc)
        gene_mappings = [slice(None), spatial_gene_loc]
        sum_stats = [adm.summary_stats for adm in self.adata_managers.values()]
        n_inputs = [s["n_vars"] for s in sum_stats]

        total_genes = n_inputs[0]

        # since we are combining datasets, we need to increment the batch_idx
        # of one of the datasets
        adata_seq_n_batches = sum_stats[0]["n_batch"]
        adata_spatial.obs[
            self.adata_managers["spatial"]
            .data_registry[REGISTRY_KEYS.BATCH_KEY]
            .attr_key
        ] += adata_seq_n_batches

        n_batches = sum(s["n_batch"] for s in sum_stats)

        library_log_means = []
        library_log_vars = []
        for adata_manager in self.adata_managers.values():
            adata_library_log_means, adata_library_log_vars = _init_library_size(
                adata_manager, n_batches
            )
            library_log_means.append(adata_library_log_means)
            library_log_vars.append(adata_library_log_vars)

        self.module = JVAE(
            n_inputs,
            total_genes,
            gene_mappings,
            generative_distributions,
            model_library_size,
            library_log_means,
            library_log_vars,
            n_batch=n_batches,
            n_latent=n_latent,
            **model_kwargs,
        )

        self._model_summary_string = (
            "GimVI Model with the following params: \nn_latent: {}, n_inputs: {}, n_genes: {}, "
            + "n_batch: {}, generative distributions: {}"
        ).format(n_latent, n_inputs, total_genes, n_batches, generative_distributions)
        self.init_params_ = self._get_init_params(locals())

    def train(
        self,
        max_epochs: int = 200,
        use_gpu: Optional[Union[str, int, bool]] = None,
        kappa: int = 5,
        train_size: float = 0.9,
        validation_size: Optional[float] = None,
        batch_size: int = 128,
        plan_kwargs: Optional[dict] = None,
        **kwargs,
    ):
        """
        Train the model.

        Parameters
        ----------
        max_epochs
            Number of passes through the dataset.
        use_gpu
            Use default GPU if available (if None or True), or index of GPU to use (if int),
            or name of GPU (if str, e.g., `'cuda:0'`), or use CPU (if False).
        kappa
            Scaling parameter for the discriminator loss.
        train_size
            Size of training set in the range [0.0, 1.0].
        validation_size
            Size of the validation set. If `None`, defaults to 1 - `train_size`. If
            `train_size + validation_size < 1`, the remaining cells belong to a test set.
        batch_size
            Minibatch size to use during training.
        plan_kwargs
            Keyword args for model-specific Pytorch Lightning task. Keyword arguments passed to
            `train()` will overwrite values present in `plan_kwargs`, when appropriate.
        **kwargs
            Other keyword args for :class:`~scvi.train.Trainer`.
        """
        gpus, device = parse_use_gpu_arg(use_gpu)

        self.trainer = Trainer(
            max_epochs=max_epochs,
            gpus=gpus,
            **kwargs,
        )
        self.train_indices_, self.test_indices_, self.validation_indices_ = [], [], []
        train_dls, val_dls = [], []
        # One split per dataset; `i` follows the order of `self.adata_managers`
        # (seq first, then spatial).
        for i, adm in enumerate(self.adata_managers.values()):
            ds = DataSplitter(
                adm,
                train_size=train_size,
                validation_size=validation_size,
                batch_size=batch_size,
                use_gpu=use_gpu,
            )
            ds.setup()
            train_dls.append(ds.train_dataloader())
            val = ds.val_dataloader()
            val_dls.append(val)
            val.mode = i
            self.train_indices_.append(ds.train_idx)
            self.test_indices_.append(ds.test_idx)
            self.validation_indices_.append(ds.val_idx)
        train_dl = TrainDL(train_dls)

        plan_kwargs = plan_kwargs if isinstance(plan_kwargs, dict) else dict()
        self._training_plan = GIMVITrainingPlan(
            self.module,
            adversarial_classifier=True,
            scale_adversarial_loss=kappa,
            **plan_kwargs,
        )

        if train_size == 1.0:
            # circumvent the empty data loader problem if all dataset used for training
            self.trainer.fit(self._training_plan, train_dl)
        else:
            # accepts list of val dataloaders
            self.trainer.fit(self._training_plan, train_dl, val_dls)
        try:
            self.history_ = self.trainer.logger.history
        except AttributeError:
            self.history_ = None
        self.module.eval()

        self.to_device(device)
        self.is_trained_ = True

    def _make_scvi_dls(self, adatas: Optional[List[AnnData]] = None, batch_size=128):
        """Build one data loader per AnnData and set ``mode`` to its position in ``adatas``."""
        if adatas is None:
            adatas = self.adatas
        post_list = [
            self._make_data_loader(ad, batch_size=batch_size) for ad in adatas
        ]
        for i, dl in enumerate(post_list):
            dl.mode = i

        return post_list

    @torch.no_grad()
    def get_latent_representation(
        self,
        adatas: List[AnnData] = None,
        deterministic: bool = True,
        batch_size: int = 128,
    ) -> List[np.ndarray]:
        """
        Return the latent space embedding for each dataset.

        Parameters
        ----------
        adatas
            List of adata seq and adata spatial.
        deterministic
            If true, use the mean of the encoder instead of a Gaussian sample.
        batch_size
            Minibatch size for data loading into model.

        Returns
        -------
        One array per dataset, in the order of ``adatas``.
        """
        if adatas is None:
            adatas = self.adatas
        scdls = self._make_scvi_dls(adatas, batch_size=batch_size)
        self.module.eval()
        latents = []
        for mode, scdl in enumerate(scdls):
            latent = []
            for tensors in scdl:
                (
                    sample_batch,
                    *_,
                ) = _unpack_tensors(tensors)
                latent.append(
                    self.module.sample_from_posterior_z(
                        sample_batch, mode, deterministic=deterministic
                    )
                )

            latent = torch.cat(latent).cpu().detach().numpy()
            latents.append(latent)

        return latents

    @torch.no_grad()
    def get_imputed_values(
        self,
        adatas: List[AnnData] = None,
        deterministic: bool = True,
        normalized: bool = True,
        decode_mode: Optional[int] = None,
        batch_size: int = 128,
    ) -> List[np.ndarray]:
        """
        Return imputed values for all genes for each dataset.

        Parameters
        ----------
        adatas
            List of adata seq and adata spatial
        deterministic
            If true, use the mean of the encoder instead of a Gaussian sample for the latent vector.
        normalized
            Return imputed normalized values or not.
        decode_mode
            If a `decode_mode` is given, use the encoder specific to each dataset as usual but use
            the decoder of the dataset of id `decode_mode` to impute values.
        batch_size
            Minibatch size for data loading into model.

        Returns
        -------
        One array per dataset, in the order of ``adatas``.
        """
        self.module.eval()

        if adatas is None:
            adatas = self.adatas
        scdls = self._make_scvi_dls(adatas, batch_size=batch_size)

        imputed_values = []
        for mode, scdl in enumerate(scdls):
            imputed_value = []
            for tensors in scdl:
                (
                    sample_batch,
                    batch_index,
                    label,
                    *_,
                ) = _unpack_tensors(tensors)
                if normalized:
                    imputed_value.append(
                        self.module.sample_scale(
                            sample_batch,
                            mode,
                            batch_index,
                            label,
                            deterministic=deterministic,
                            decode_mode=decode_mode,
                        )
                    )
                else:
                    imputed_value.append(
                        self.module.sample_rate(
                            sample_batch,
                            mode,
                            batch_index,
                            label,
                            deterministic=deterministic,
                            decode_mode=decode_mode,
                        )
                    )

            imputed_value = torch.cat(imputed_value).cpu().detach().numpy()
            imputed_values.append(imputed_value)

        return imputed_values

    def save(
        self,
        dir_path: str,
        prefix: Optional[str] = None,
        overwrite: bool = False,
        save_anndata: bool = False,
        **anndata_write_kwargs,
    ):
        """
        Save the state of the model.

        Neither the trainer optimizer state nor the trainer history are saved.
        Model files are not expected to be reproducibly saved and loaded across versions
        until we reach version 1.0.

        Parameters
        ----------
        dir_path
            Path to a directory.
        prefix
            Prefix to prepend to saved file names.
        overwrite
            Overwrite existing data or not. If `False` and directory
            already exists at `dir_path`, error will be raised.
        save_anndata
            If True, also saves the anndata
        anndata_write_kwargs
            Kwargs for anndata write function
        """
        if not os.path.exists(dir_path) or overwrite:
            os.makedirs(dir_path, exist_ok=overwrite)
        else:
            raise ValueError(
                "{} already exists. Please provide an unexisting directory for saving.".format(
                    dir_path
                )
            )

        file_name_prefix = prefix or ""

        seq_adata = self.adatas[0]
        spatial_adata = self.adatas[1]
        if save_anndata:
            seq_save_path = os.path.join(dir_path, f"{file_name_prefix}adata_seq.h5ad")
            seq_adata.write(seq_save_path, **anndata_write_kwargs)

            spatial_save_path = os.path.join(
                dir_path, f"{file_name_prefix}adata_spatial.h5ad"
            )
            spatial_adata.write(spatial_save_path, **anndata_write_kwargs)

        # save the model state dict and the trainer state dict only
        model_state_dict = self.module.state_dict()

        seq_var_names = seq_adata.var_names.astype(str).to_numpy()
        spatial_var_names = spatial_adata.var_names.astype(str).to_numpy()

        # get all the user attributes
        user_attributes = self._get_user_attributes()
        # only save the public attributes with _ at the very end
        user_attributes = {a[0]: a[1] for a in user_attributes if a[0][-1] == "_"}

        model_save_path = os.path.join(dir_path, f"{file_name_prefix}model.pt")

        torch.save(
            dict(
                model_state_dict=model_state_dict,
                seq_var_names=seq_var_names,
                spatial_var_names=spatial_var_names,
                attr_dict=user_attributes,
            ),
            model_save_path,
        )

    @classmethod
    def load(
        cls,
        dir_path: str,
        prefix: Optional[str] = None,
        adata_seq: Optional[AnnData] = None,
        adata_spatial: Optional[AnnData] = None,
        use_gpu: Optional[Union[str, int, bool]] = None,
    ):
        """
        Instantiate a model from the saved output.

        Parameters
        ----------
        dir_path
            Path to saved outputs.
        prefix
            Prefix of saved file names.
        adata_seq
            AnnData organized in the same way as data used to train model.
            It is not necessary to run :meth:`~scvi.external.GIMVI.setup_anndata`,
            as the AnnData is set up again from the setup saved with the model.
            If None, will check for and load anndata saved with the model.
        adata_spatial
            AnnData organized in the same way as data used to train model.
            If None, will check for and load anndata saved with the model.
        use_gpu
            Load model on default GPU if available (if None or True),
            or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False).

        Returns
        -------
        Model with loaded state dictionaries.

        Examples
        --------
        >>> vae = GIMVI.load(save_path, adata_seq=adata_seq, adata_spatial=adata_spatial)
        >>> vae.get_latent_representation()
        """
        _, device = parse_use_gpu_arg(use_gpu)

        (
            attr_dict,
            seq_var_names,
            spatial_var_names,
            model_state_dict,
            loaded_adata_seq,
            loaded_adata_spatial,
        ) = _load_saved_gimvi_files(
            dir_path,
            adata_seq is None,
            adata_spatial is None,
            prefix=prefix,
            map_location=device,
        )
        # Explicit None checks: AnnData truthiness depends on its number of cells.
        if loaded_adata_seq is not None:
            adata_seq = loaded_adata_seq
        if loaded_adata_spatial is not None:
            adata_spatial = loaded_adata_spatial
        adatas = [adata_seq, adata_spatial]

        var_names = [seq_var_names, spatial_var_names]

        for adata, saved_var_names in zip(adatas, var_names):
            user_var_names = adata.var_names.astype(str)
            if not np.array_equal(saved_var_names, user_var_names):
                warnings.warn(
                    "var_names for adata passed in does not match var_names of "
                    "adata used to train the model. For valid results, the vars "
                    "need to be the same and in the same order as the adata used to train the model."
                )

        if "scvi_setup_dicts_" in attr_dict:
            scvi_setup_dicts = attr_dict.pop("scvi_setup_dicts_")
            # Legacy models stored {"seq": ..., "spatial": ...}; iterating the dict
            # itself would yield the key strings instead of the setup dicts.
            if isinstance(scvi_setup_dicts, dict):
                scvi_setup_dicts = list(scvi_setup_dicts.values())
            for adata, scvi_setup_dict in zip(adatas, scvi_setup_dicts):
                cls.register_manager(
                    manager_from_setup_dict(cls, adata, scvi_setup_dict)
                )
        else:
            registries = attr_dict.pop("registries_")
            for adata, registry in zip(adatas, registries):
                if (
                    _MODEL_NAME_KEY in registry
                    and registry[_MODEL_NAME_KEY] != cls.__name__
                ):
                    raise ValueError(
                        "It appears you are loading a model from a different class."
                    )

                if _SETUP_KWARGS_KEY not in registry:
                    raise ValueError(
                        "Saved model does not contain original setup inputs. "
                        "Cannot load the original setup."
                    )

                cls.setup_anndata(
                    adata, source_registry=registry, **registry[_SETUP_KWARGS_KEY]
                )

        # get the parameters for the class init signature
        init_params = attr_dict.pop("init_params_")

        # new saving and loading, enable backwards compatibility
        if "non_kwargs" in init_params.keys():
            # grab all the parameters except for kwargs (is a dict)
            non_kwargs = init_params["non_kwargs"]
            kwargs = init_params["kwargs"]

            # expand out kwargs
            kwargs = {k: v for (i, j) in kwargs.items() for (k, v) in j.items()}
        else:
            # grab all the parameters except for kwargs (is a dict)
            non_kwargs = {
                k: v for k, v in init_params.items() if not isinstance(v, dict)
            }
            kwargs = {k: v for k, v in init_params.items() if isinstance(v, dict)}
            kwargs = {k: v for (i, j) in kwargs.items() for (k, v) in j.items()}

        model = cls(adata_seq, adata_spatial, **non_kwargs, **kwargs)

        for attr, val in attr_dict.items():
            setattr(model, attr, val)

        model.module.load_state_dict(model_state_dict)
        model.module.eval()
        model.to_device(device)
        return model

    @classmethod
    @setup_anndata_dsp.dedent
    def setup_anndata(
        cls,
        adata: AnnData,
        batch_key: Optional[str] = None,
        labels_key: Optional[str] = None,
        layer: Optional[str] = None,
        **kwargs,
    ):
        """
        %(summary)s.

        Parameters
        ----------
        %(param_batch_key)s
        %(param_labels_key)s
        %(param_layer)s
        """
        setup_method_args = cls._get_setup_method_args(**locals())
        anndata_fields = [
            LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
            CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
            CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key),
        ]
        adata_manager = AnnDataManager(
            fields=anndata_fields, setup_method_args=setup_method_args
        )
        adata_manager.register_fields(adata, **kwargs)
        cls.register_manager(adata_manager)
class GIMVI(VAEMixin, BaseModelClass): """ Joint VAE for imputing missing genes in spatial data [Lopez19]_. Parameters ---------- adata_seq AnnData object that has been registered via :func:`~scvi.data.setup_anndata` and contains RNA-seq data. adata_spatial AnnData object that has been registered via :func:`~scvi.data.setup_anndata` and contains spatial data. n_hidden Number of nodes per hidden layer. generative_distributions List of generative distribution for adata_seq data and adata_spatial data. model_library_size List of bool of whether to model library size for adata_seq and adata_spatial. n_latent Dimensionality of the latent space. **model_kwargs Keyword args for :class:`~scvi.external.gimvi.JVAE` Examples -------- >>> adata_seq = anndata.read_h5ad(path_to_anndata_seq) >>> adata_spatial = anndata.read_h5ad(path_to_anndata_spatial) >>> scvi.data.setup_anndata(adata_seq) >>> scvi.data.setup_anndata(adata_spatial) >>> vae = scvi.model.GIMVI(adata_seq, adata_spatial) >>> vae.train(n_epochs=400) Notes ----- See further usage examples in the following tutorials: 1. 
:doc:`/user_guide/notebooks/gimvi_tutorial` """ def __init__( self, adata_seq: AnnData, adata_spatial: AnnData, generative_distributions: List = ["zinb", "nb"], model_library_size: List = [True, False], n_latent: int = 10, **model_kwargs, ): super(GIMVI, self).__init__() self.adatas = [adata_seq, adata_spatial] self.scvi_setup_dicts_ = { "seq": adata_seq.uns["_scvi"], "spatial": adata_spatial.uns["_scvi"], } seq_var_names = _get_var_names_from_setup_anndata(adata_seq) spatial_var_names = _get_var_names_from_setup_anndata(adata_spatial) if not set(spatial_var_names) <= set(seq_var_names): raise ValueError("spatial genes needs to be subset of seq genes") spatial_gene_loc = [ np.argwhere(seq_var_names == g)[0] for g in spatial_var_names ] spatial_gene_loc = np.concatenate(spatial_gene_loc) gene_mappings = [slice(None), spatial_gene_loc] sum_stats = [d.uns["_scvi"]["summary_stats"] for d in self.adatas] n_inputs = [s["n_vars"] for s in sum_stats] total_genes = adata_seq.uns["_scvi"]["summary_stats"]["n_vars"] # since we are combining datasets, we need to increment the batch_idx # of one of the datasets adata_seq_n_batches = adata_seq.uns["_scvi"]["summary_stats"]["n_batch"] adata_spatial.obs["_scvi_batch"] += adata_seq_n_batches n_batches = sum([s["n_batch"] for s in sum_stats]) self.module = JVAE( n_inputs, total_genes, gene_mappings, generative_distributions, model_library_size, n_batch=n_batches, n_latent=n_latent, **model_kwargs, ) self._model_summary_string = ( "GimVI Model with the following params: \nn_latent: {}, n_inputs: {}, n_genes: {}, " + "n_batch: {}, generative distributions: {}" ).format(n_latent, n_inputs, total_genes, n_batches, generative_distributions) self.init_params_ = self._get_init_params(locals()) def train( self, max_epochs: int = 200, use_gpu: Optional[Union[str, int, bool]] = None, kappa: int = 5, train_size: float = 0.9, validation_size: Optional[float] = None, batch_size: int = 128, plan_kwargs: Optional[dict] = None, **kwargs, ): """ 
Train the model. Parameters ---------- max_epochs Number of passes through the dataset. If `None`, defaults to `np.min([round((20000 / n_cells) * 400), 400])` use_gpu Use default GPU if available (if None or True), or index of GPU to use (if int), or name of GPU (if str, e.g., `'cuda:0'`), or use CPU (if False). kappa Scaling parameter for the discriminator loss. train_size Size of training set in the range [0.0, 1.0]. validation_size Size of the test set. If `None`, defaults to 1 - `train_size`. If `train_size + validation_size < 1`, the remaining cells belong to a test set. batch_size Minibatch size to use during training. plan_kwargs Keyword args for model-specific Pytorch Lightning task. Keyword arguments passed to `train()` will overwrite values present in `plan_kwargs`, when appropriate. **kwargs Other keyword args for :class:`~scvi.train.Trainer`. """ gpus, device = parse_use_gpu_arg(use_gpu) self.trainer = Trainer( max_epochs=max_epochs, gpus=gpus, **kwargs, ) self.train_indices_, self.test_indices_, self.validation_indices_ = [], [], [] train_dls, test_dls, val_dls = [], [], [] for i, ad in enumerate(self.adatas): ds = DataSplitter( ad, train_size=train_size, validation_size=validation_size, batch_size=batch_size, use_gpu=use_gpu, ) ds.setup() train_dls.append(ds.train_dataloader()) test_dls.append(ds.test_dataloader()) val = ds.val_dataloader() val_dls.append(val) val.mode = i self.train_indices_.append(ds.train_idx) self.test_indices_.append(ds.test_idx) self.validation_indices_.append(ds.val_idx) train_dl = TrainDL(train_dls) plan_kwargs = plan_kwargs if isinstance(plan_kwargs, dict) else dict() self._training_plan = GIMVITrainingPlan( self.module, adversarial_classifier=True, scale_adversarial_loss=kappa, **plan_kwargs, ) if train_size == 1.0: # circumvent the empty data loader problem if all dataset used for training self.trainer.fit(self._training_plan, train_dl) else: # accepts list of val dataloaders self.trainer.fit(self._training_plan, train_dl, 
val_dls) try: self.history_ = self.trainer.logger.history except AttributeError: self.history_ = None self.module.eval() self.to_device(device) self.is_trained_ = True def _make_scvi_dls(self, adatas: List[AnnData] = None, batch_size=128): if adatas is None: adatas = self.adatas post_list = [self._make_data_loader(ad) for ad in adatas] for i, dl in enumerate(post_list): dl.mode = i return post_list @torch.no_grad() def get_latent_representation( self, adatas: List[AnnData] = None, deterministic: bool = True, batch_size: int = 128, ) -> List[np.ndarray]: """ Return the latent space embedding for each dataset. Parameters ---------- adatas List of adata seq and adata spatial. deterministic If true, use the mean of the encoder instead of a Gaussian sample. batch_size Minibatch size for data loading into model. """ if adatas is None: adatas = self.adatas scdls = self._make_scvi_dls(adatas, batch_size=batch_size) self.module.eval() latents = [] for mode, scdl in enumerate(scdls): latent = [] for tensors in scdl: ( sample_batch, local_l_mean, local_l_var, batch_index, label, *_, ) = _unpack_tensors(tensors) latent.append( self.module.sample_from_posterior_z( sample_batch, mode, deterministic=deterministic ) ) latent = torch.cat(latent).cpu().detach().numpy() latents.append(latent) return latents @torch.no_grad() def get_imputed_values( self, adatas: List[AnnData] = None, deterministic: bool = True, normalized: bool = True, decode_mode: Optional[int] = None, batch_size: int = 128, ) -> List[np.ndarray]: """ Return imputed values for all genes for each dataset. Parameters ---------- adatas List of adata seq and adata spatial deterministic If true, use the mean of the encoder instead of a Gaussian sample for the latent vector. normalized Return imputed normalized values or not. decode_mode If a `decode_mode` is given, use the encoder specific to each dataset as usual but use the decoder of the dataset of id `decode_mode` to impute values. 
batch_size Minibatch size for data loading into model. """ self.module.eval() if adatas is None: adatas = self.adatas scdls = self._make_scvi_dls(adatas, batch_size=batch_size) imputed_values = [] for mode, scdl in enumerate(scdls): imputed_value = [] for tensors in scdl: ( sample_batch, local_l_mean, local_l_var, batch_index, label, *_, ) = _unpack_tensors(tensors) if normalized: imputed_value.append( self.module.sample_scale( sample_batch, mode, batch_index, label, deterministic=deterministic, decode_mode=decode_mode, ) ) else: imputed_value.append( self.module.sample_rate( sample_batch, mode, batch_index, label, deterministic=deterministic, decode_mode=decode_mode, ) ) imputed_value = torch.cat(imputed_value).cpu().detach().numpy() imputed_values.append(imputed_value) return imputed_values def save( self, dir_path: str, overwrite: bool = False, save_anndata: bool = False, **anndata_write_kwargs, ): """ Save the state of the model. Neither the trainer optimizer state nor the trainer history are saved. Model files are not expected to be reproducibly saved and loaded across versions until we reach version 1.0. Parameters ---------- dir_path Path to a directory. overwrite Overwrite existing data or not. If `False` and directory already exists at `dir_path`, error will be raised. save_anndata If True, also saves the anndata anndata_write_kwargs Kwargs for anndata write function """ # get all the user attributes user_attributes = self._get_user_attributes() # only save the public attributes with _ at the very end user_attributes = {a[0]: a[1] for a in user_attributes if a[0][-1] == "_"} # save the model state dict and the trainer state dict only if not os.path.exists(dir_path) or overwrite: os.makedirs(dir_path, exist_ok=overwrite) else: raise ValueError( "{} already exists. 
Please provide an unexisting directory for saving.".format( dir_path ) ) if save_anndata: dataset_names = ["seq", "spatial"] for i in range(len(self.adatas)): save_path = os.path.join( dir_path, "adata_{}.h5ad".format(dataset_names[i]) ) self.adatas[i].write(save_path) varnames_save_path = os.path.join( dir_path, "var_names_{}.csv".format(dataset_names[i]) ) var_names = self.adatas[i].var_names.astype(str) var_names = var_names.to_numpy() np.savetxt(varnames_save_path, var_names, fmt="%s") model_save_path = os.path.join(dir_path, "model_params.pt") attr_save_path = os.path.join(dir_path, "attr.pkl") torch.save(self.module.state_dict(), model_save_path) with open(attr_save_path, "wb") as f: pickle.dump(user_attributes, f) @classmethod def load( cls, dir_path: str, adata_seq: Optional[AnnData] = None, adata_spatial: Optional[AnnData] = None, use_gpu: Optional[Union[str, int, bool]] = None, ): """ Instantiate a model from the saved output. Parameters ---------- adata_seq AnnData organized in the same way as data used to train model. It is not necessary to run :func:`~scvi.data.setup_anndata`, as AnnData is validated against the saved `scvi` setup dictionary. AnnData must be registered via :func:`~scvi.data.setup_anndata`. adata_spatial AnnData organized in the same way as data used to train model. If None, will check for and load anndata saved with the model. dir_path Path to saved outputs. use_gpu Load model on default GPU if available (if None or True), or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False). Returns ------- Model with loaded state dictionaries. 
Examples -------- >>> vae = GIMVI.load(adata_seq, adata_spatial, save_path) >>> vae.get_latent_representation() """ model_path = os.path.join(dir_path, "model_params.pt") setup_dict_path = os.path.join(dir_path, "attr.pkl") seq_data_path = os.path.join(dir_path, "adata_seq.h5ad") spatial_data_path = os.path.join(dir_path, "adata_spatial.h5ad") seq_var_names_path = os.path.join(dir_path, "var_names_seq.csv") spatial_var_names_path = os.path.join(dir_path, "var_names_spatial.csv") if adata_seq is None and os.path.exists(seq_data_path): adata_seq = read(seq_data_path) elif adata_seq is None and not os.path.exists(seq_data_path): raise ValueError( "Save path contains no saved anndata and no adata was passed." ) if adata_spatial is None and os.path.exists(spatial_data_path): adata_spatial = read(spatial_data_path) elif adata_spatial is None and not os.path.exists(spatial_data_path): raise ValueError( "Save path contains no saved anndata and no adata was passed." ) adatas = [adata_seq, adata_spatial] seq_var_names = np.genfromtxt(seq_var_names_path, delimiter=",", dtype=str) spatial_var_names = np.genfromtxt( spatial_var_names_path, delimiter=",", dtype=str ) var_names = [seq_var_names, spatial_var_names] for i, adata in enumerate(adatas): saved_var_names = var_names[i] user_var_names = adata.var_names.astype(str) if not np.array_equal(saved_var_names, user_var_names): warnings.warn( "var_names for adata passed in does not match var_names of " "adata used to train the model. For valid results, the vars " "need to be the same and in the same order as the adata used to train the model." 
) with open(setup_dict_path, "rb") as handle: attr_dict = pickle.load(handle) scvi_setup_dicts = attr_dict.pop("scvi_setup_dicts_") transfer_anndata_setup(scvi_setup_dicts["seq"], adata_seq) transfer_anndata_setup(scvi_setup_dicts["spatial"], adata_spatial) # get the parameters for the class init signiture init_params = attr_dict.pop("init_params_") # new saving and loading, enable backwards compatibility if "non_kwargs" in init_params.keys(): # grab all the parameters execept for kwargs (is a dict) non_kwargs = init_params["non_kwargs"] kwargs = init_params["kwargs"] # expand out kwargs kwargs = {k: v for (i, j) in kwargs.items() for (k, v) in j.items()} else: # grab all the parameters execept for kwargs (is a dict) non_kwargs = { k: v for k, v in init_params.items() if not isinstance(v, dict) } kwargs = {k: v for k, v in init_params.items() if isinstance(v, dict)} kwargs = {k: v for (i, j) in kwargs.items() for (k, v) in j.items()} model = cls(adata_seq, adata_spatial, **non_kwargs, **kwargs) for attr, val in attr_dict.items(): setattr(model, attr, val) _, device = parse_use_gpu_arg(use_gpu) model.module.load_state_dict(torch.load(model_path, map_location=device)) model.module.eval() model.to_device(device) return model