def _posterior_quantile(self, q: float = 0.5, batch_size: int = 2048, use_gpu: bool = True):
    """
    Compute quantile of the posterior distribution of each parameter,
    for pyro models trained without amortised inference.

    Parameters
    ----------
    q
        Quantile to compute
    batch_size
        Number of observations per batch
    use_gpu
        Bool, use gpu?

    Returns
    -------
    dictionary {variable_name: posterior quantile}

    """
    self.module.eval()
    gpus, device = parse_use_gpu_arg(use_gpu)

    train_dl = AnnDataLoader(self.adata, shuffle=False, batch_size=batch_size)

    # sample global parameters
    tensor_dict = next(iter(train_dl))
    args, kwargs = self.module._get_fn_args_from_batch(tensor_dict)
    args = [a.to(device) for a in args]
    kwargs = {k: v.to(device) for k, v in kwargs.items()}
    self.to_device(device)

    means = self.module.guide.quantiles([q], *args, **kwargs)
    means = {k: means[k].cpu().detach().numpy() for k in means.keys()}

    return means
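# Usage sketch (hedged): assumes a trained Pyro-based model instance `mod`
# whose guide implements `quantiles`; the variable name is hypothetical.
# >>> medians = mod._posterior_quantile(q=0.5, use_gpu=False)
# >>> {k: v.shape for k, v in medians.items()}  # inspect each parameter's shape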
def __call__(self, remake_splits=False):
    if remake_splits:
        self.train_idx, self.test_idx, self.val_idx = self.make_splits()

    gpus = parse_use_gpu_arg(self.use_gpu, return_device=False)
    pin_memory = (
        True if (settings.dl_pin_memory_gpu_training and gpus != 0) else False
    )

    # do not remove drop_last=3, skips over small minibatches
    return (
        AnnDataLoader(
            self.adata,
            indices=self.train_idx,
            shuffle=True,
            drop_last=3,
            pin_memory=pin_memory,
            **self.data_loader_kwargs,
        ),
        AnnDataLoader(
            self.adata,
            indices=self.val_idx,
            shuffle=True,
            drop_last=3,
            pin_memory=pin_memory,
            **self.data_loader_kwargs,
        ),
        AnnDataLoader(
            self.adata,
            indices=self.test_idx,
            shuffle=True,
            drop_last=3,
            pin_memory=pin_memory,
            **self.data_loader_kwargs,
        ),
    )
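# Usage sketch (hedged): assumes `splitter` is an instance of the data-splitting
# class this method belongs to, already set up with an AnnData object.
# >>> train_dl, val_dl, test_dl = splitter(remake_splits=True)
# >>> batch = next(iter(train_dl))  # dict of tensors for one minibatch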
def setup(self, stage: Optional[str] = None):
    """Split indices in train/test/val sets."""
    n = self.adata.n_obs
    n_train, n_val = validate_data_split(n, self.train_size, self.validation_size)
    random_state = np.random.RandomState(seed=settings.seed)
    permutation = random_state.permutation(n)
    self.val_idx = permutation[:n_val]
    self.train_idx = permutation[n_val : (n_val + n_train)]
    self.test_idx = permutation[(n_val + n_train) :]

    gpus, self.device = parse_use_gpu_arg(self.use_gpu, return_device=True)
    self.pin_memory = (
        True if (settings.dl_pin_memory_gpu_training and gpus != 0) else False
    )
def __init__(
    self,
    model: BaseModelClass,
    training_plan: pl.LightningModule,
    data_splitter: Union[SemiSupervisedDataSplitter, DataSplitter],
    max_epochs: int,
    use_gpu: Optional[Union[str, int, bool]] = None,
    **trainer_kwargs,
):
    self.training_plan = training_plan
    self.data_splitter = data_splitter
    self.model = model
    gpus, device = parse_use_gpu_arg(use_gpu)
    self.gpus = gpus
    self.device = device
    self.trainer = Trainer(max_epochs=max_epochs, gpus=gpus, **trainer_kwargs)
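# Usage sketch (hedged): this runner wires a model, a training plan and a data
# splitter into a PyTorch Lightning Trainer; the class name `TrainRunner` and
# the variables below are assumptions.
# >>> runner = TrainRunner(model, training_plan=plan, data_splitter=splitter,
# ...                      max_epochs=400, use_gpu=True)
# >>> runner.trainer  # configured Trainer, ready to fit the training plan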
def __call__(self, remake_splits=False):
    if remake_splits:
        self.train_idx, self.test_idx, self.val_idx = self.make_splits()

    gpus = parse_use_gpu_arg(self.use_gpu, return_device=False)
    pin_memory = (
        True if (settings.dl_pin_memory_gpu_training and gpus != 0) else False
    )

    if len(self._labeled_indices) != 0:
        data_loader_class = SemiSupervisedDataLoader
        dl_kwargs = {
            "unlabeled_category": self.unlabeled_category,
            "n_samples_per_label": self.n_samples_per_label,
        }
    else:
        data_loader_class = AnnDataLoader
        dl_kwargs = {}

    dl_kwargs.update(self.data_loader_kwargs)

    scanvi_train_dl = data_loader_class(
        self.adata,
        indices=self.train_idx,
        shuffle=True,
        drop_last=3,
        pin_memory=pin_memory,
        **dl_kwargs,
    )
    scanvi_val_dl = data_loader_class(
        self.adata,
        indices=self.val_idx,
        shuffle=True,
        drop_last=3,
        pin_memory=pin_memory,
        **dl_kwargs,
    )
    scanvi_test_dl = data_loader_class(
        self.adata,
        indices=self.test_idx,
        shuffle=True,
        drop_last=3,
        pin_memory=pin_memory,
        **dl_kwargs,
    )

    return scanvi_train_dl, scanvi_val_dl, scanvi_test_dl
def _posterior_quantile(
    self,
    q: float = 0.5,
    batch_size: Optional[int] = None,
    use_gpu: Optional[bool] = None,
    use_median: bool = False,
):
    """
    Compute quantile of the posterior distribution of each parameter,
    for pyro models trained without amortised inference.

    Parameters
    ----------
    q
        Quantile to compute
    batch_size
        Number of observations per batch; defaults to the full dataset
    use_gpu
        Bool, use gpu?
    use_median
        Bool, when q=0.5 use median rather than quantile method of the guide

    Returns
    -------
    dictionary {variable_name: posterior quantile}

    """
    self.module.eval()
    gpus, device = parse_use_gpu_arg(use_gpu)
    if batch_size is None:
        batch_size = self.adata_manager.adata.n_obs
    train_dl = AnnDataLoader(self.adata_manager, shuffle=False, batch_size=batch_size)

    # sample global parameters
    tensor_dict = next(iter(train_dl))
    args, kwargs = self.module._get_fn_args_from_batch(tensor_dict)
    args = [a.to(device) for a in args]
    kwargs = {k: v.to(device) for k, v in kwargs.items()}
    self.to_device(device)

    if use_median and q == 0.5:
        means = self.module.guide.median(*args, **kwargs)
    else:
        means = self.module.guide.quantiles([q], *args, **kwargs)
    means = {k: means[k].cpu().detach().numpy() for k in means.keys()}

    return means
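# Usage sketch (hedged): with use_median=True and q=0.5 the guide's exact
# `median` method is used; any other quantile falls back to `quantiles`.
# >>> medians = mod._posterior_quantile(use_median=True)  # `mod` is hypothetical
# >>> q95 = mod._posterior_quantile(q=0.95)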
def _posterior_quantile_minibatch(
    self,
    q: float = 0.5,
    batch_size: int = 2048,
    use_gpu: Optional[bool] = None,
    use_median: bool = False,
):
    """
    Compute quantile of the posterior distribution of each parameter,
    separating local (minibatch) variables and global variables, which is
    necessary when performing amortised inference.

    Note for developers: requires model class method which lists
    observation/minibatch plate variables (self.module.model.list_obs_plate_vars()).

    Parameters
    ----------
    q
        Quantile to compute
    batch_size
        Number of observations per batch
    use_gpu
        Bool, use gpu?
    use_median
        Bool, when q=0.5 use median rather than quantile method of the guide

    Returns
    -------
    dictionary {variable_name: posterior quantile}

    """
    gpus, device = parse_use_gpu_arg(use_gpu)

    self.module.eval()

    train_dl = AnnDataLoader(self.adata_manager, shuffle=False, batch_size=batch_size)

    # sample local parameters
    means = {}  # stays empty if there are no local variables
    i = 0
    for tensor_dict in train_dl:
        args, kwargs = self.module._get_fn_args_from_batch(tensor_dict)
        args = [a.to(device) for a in args]
        kwargs = {k: v.to(device) for k, v in kwargs.items()}
        self.to_device(device)

        if i == 0:
            # find plate sites
            obs_plate_sites = self._get_obs_plate_sites(
                args, kwargs, return_observed=True
            )
            if len(obs_plate_sites) == 0:
                # if there are no local variables - don't sample
                break
            # find plate dimension
            obs_plate_dim = list(obs_plate_sites.values())[0]
            if use_median and q == 0.5:
                means = self.module.guide.median(*args, **kwargs)
            else:
                means = self.module.guide.quantiles([q], *args, **kwargs)
            means = {
                k: means[k].cpu().numpy()
                for k in means.keys()
                if k in obs_plate_sites
            }
        else:
            if use_median and q == 0.5:
                means_ = self.module.guide.median(*args, **kwargs)
            else:
                means_ = self.module.guide.quantiles([q], *args, **kwargs)
            means_ = {
                k: means_[k].cpu().numpy()
                for k in means_.keys()
                if k in obs_plate_sites
            }
            means = {
                k: np.concatenate([means[k], means_[k]], axis=obs_plate_dim)
                for k in means.keys()
            }
        i += 1

    # sample global parameters
    tensor_dict = next(iter(train_dl))
    args, kwargs = self.module._get_fn_args_from_batch(tensor_dict)
    args = [a.to(device) for a in args]
    kwargs = {k: v.to(device) for k, v in kwargs.items()}
    self.to_device(device)

    if use_median and q == 0.5:
        global_means = self.module.guide.median(*args, **kwargs)
    else:
        global_means = self.module.guide.quantiles([q], *args, **kwargs)
    global_means = {
        k: global_means[k].cpu().numpy()
        for k in global_means.keys()
        if k not in obs_plate_sites
    }

    for k in global_means.keys():
        means[k] = global_means[k]

    self.module.to(device)

    return means
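# Usage sketch (hedged): local (per-observation) sites are computed one
# minibatch at a time and concatenated along the plate dimension, so the result
# matches a full-data pass while bounding memory; `mod` is hypothetical.
# >>> q05 = mod._posterior_quantile_minibatch(q=0.05, batch_size=2048)
# >>> q95 = mod._posterior_quantile_minibatch(q=0.95, batch_size=2048)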
def load(
    cls,
    dir_path: str,
    adata: Optional[AnnData] = None,
    use_gpu: Optional[Union[str, int, bool]] = None,
):
    """
    Instantiate a model from the saved output.

    Parameters
    ----------
    dir_path
        Path to saved outputs.
    adata
        AnnData organized in the same way as data used to train model.
        It is not necessary to run :func:`~scvi.data.setup_anndata`,
        as AnnData is validated against the saved `scvi` setup dictionary.
        If None, will check for and load anndata saved with the model.
    use_gpu
        Load model on default GPU if available (if None or True),
        or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False).

    Returns
    -------
    Model with loaded state dictionaries.

    Examples
    --------
    >>> vae = SCVI.load(save_path, adata)
    >>> vae.get_latent_representation()
    """
    load_adata = adata is None
    use_gpu, device = parse_use_gpu_arg(use_gpu)

    (
        scvi_setup_dict,
        attr_dict,
        var_names,
        model_state_dict,
        new_adata,
    ) = _load_saved_files(dir_path, load_adata, map_location=device)
    adata = new_adata if new_adata is not None else adata

    _validate_var_names(adata, var_names)
    transfer_anndata_setup(scvi_setup_dict, adata)
    model = _initialize_model(cls, adata, attr_dict)

    # set saved attrs for loaded model
    for attr, val in attr_dict.items():
        setattr(model, attr, val)

    # some Pyro modules with AutoGuides may need one training step
    try:
        model.module.load_state_dict(model_state_dict)
    except RuntimeError as err:
        if isinstance(model.module, PyroBaseModuleClass):
            logger.info("Preparing underlying module for load")
            model.train(max_steps=1)
            pyro.clear_param_store()
            model.module.load_state_dict(model_state_dict)
        else:
            raise err

    model.to_device(device)
    model.module.eval()
    model._validate_anndata(adata)
    return model
def load_query_data(
    cls,
    adata: AnnData,
    reference_model: Union[str, BaseModelClass],
    inplace_subset_query_vars: bool = False,
    use_gpu: Optional[Union[str, int, bool]] = None,
    unfrozen: bool = False,
    freeze_dropout: bool = False,
    freeze_expression: bool = True,
    freeze_decoder_first_layer: bool = True,
    freeze_batchnorm_encoder: bool = True,
    freeze_batchnorm_decoder: bool = False,
    freeze_classifier: bool = True,
):
    """
    Online update of a reference model with scArches algorithm [Lotfollahi21]_.

    Parameters
    ----------
    adata
        AnnData organized in the same way as data used to train model.
        It is not necessary to run setup_anndata,
        as AnnData is validated against the saved `scvi` setup dictionary.
    reference_model
        Either an already instantiated model of the same class, or a path to
        saved outputs for reference model.
    inplace_subset_query_vars
        Whether to subset and rearrange query vars inplace based on vars used to
        train reference model.
    use_gpu
        Load model on default GPU if available (if None or True),
        or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False).
    unfrozen
        Override all other freeze options for a fully unfrozen model
    freeze_dropout
        Whether to freeze dropout during training
    freeze_expression
        Freeze neurons corresponding to expression in first layer
    freeze_decoder_first_layer
        Freeze neurons corresponding to first layer in decoder
    freeze_batchnorm_encoder
        Whether to freeze batchnorm weight and bias during training for encoder
    freeze_batchnorm_decoder
        Whether to freeze batchnorm weight and bias during training for decoder
    freeze_classifier
        Whether to freeze classifier completely. Only applies to `SCANVI`.
    """
    use_gpu, device = parse_use_gpu_arg(use_gpu)
    if isinstance(reference_model, str):
        (
            attr_dict,
            var_names,
            load_state_dict,
            _,
        ) = _load_saved_files(reference_model, load_adata=False, map_location=device)
    else:
        attr_dict = reference_model._get_user_attributes()
        attr_dict = {a[0]: a[1] for a in attr_dict if a[0][-1] == "_"}
        var_names = reference_model.adata.var_names
        load_state_dict = deepcopy(reference_model.module.state_dict())

    scvi_setup_dict = attr_dict.pop("scvi_setup_dict_")

    if inplace_subset_query_vars:
        logger.debug("Subsetting query vars to reference vars.")
        adata._inplace_subset_var(var_names)
    _validate_var_names(adata, var_names)

    version_split = scvi_setup_dict["scvi_version"].split(".")
    # compare versions numerically; a string comparison would misorder e.g. "10" < "8"
    if int(version_split[0]) == 0 and int(version_split[1]) < 8:
        warnings.warn(
            "Query integration should be performed using models trained with version >= 0.8"
        )

    transfer_anndata_setup(scvi_setup_dict, adata, extend_categories=True)

    model = _initialize_model(cls, adata, attr_dict)

    # set saved attrs for loaded model
    for attr, val in attr_dict.items():
        setattr(model, attr, val)

    model.to_device(device)

    # model tweaking
    new_state_dict = model.module.state_dict()
    for key, load_ten in load_state_dict.items():
        new_ten = new_state_dict[key]
        if new_ten.size() == load_ten.size():
            continue
        # new categoricals changed size
        else:
            dim_diff = new_ten.size()[-1] - load_ten.size()[-1]
            fixed_ten = torch.cat([load_ten, new_ten[..., -dim_diff:]], dim=-1)
            load_state_dict[key] = fixed_ten

    model.module.load_state_dict(load_state_dict)
    model.module.eval()

    _set_params_online_update(
        model.module,
        unfrozen=unfrozen,
        freeze_decoder_first_layer=freeze_decoder_first_layer,
        freeze_batchnorm_encoder=freeze_batchnorm_encoder,
        freeze_batchnorm_decoder=freeze_batchnorm_decoder,
        freeze_dropout=freeze_dropout,
        freeze_expression=freeze_expression,
        freeze_classifier=freeze_classifier,
    )
    model.is_trained_ = False

    return model
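# Usage sketch (hedged): online update of a pretrained reference model with
# query data following the scArches pattern; paths and names are hypothetical.
# >>> query_model = SCVI.load_query_data(query_adata, "path/to/reference_model")
# >>> query_model.train(max_epochs=200)  # only unfrozen parameters are updated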
def load(
    cls,
    dir_path: str,
    prefix: Optional[str] = None,
    adata_seq: Optional[AnnData] = None,
    adata_spatial: Optional[AnnData] = None,
    use_gpu: Optional[Union[str, int, bool]] = None,
):
    """
    Instantiate a model from the saved output.

    Parameters
    ----------
    dir_path
        Path to saved outputs.
    prefix
        Prefix of saved file names.
    adata_seq
        AnnData organized in the same way as data used to train model.
        It is not necessary to run :meth:`~scvi.external.GIMVI.setup_anndata`,
        as AnnData is validated against the saved `scvi` setup dictionary.
        AnnData must be registered via :meth:`~scvi.external.GIMVI.setup_anndata`.
    adata_spatial
        AnnData organized in the same way as data used to train model.
        If None, will check for and load anndata saved with the model.
    use_gpu
        Load model on default GPU if available (if None or True),
        or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False).

    Returns
    -------
    Model with loaded state dictionaries.

    Examples
    --------
    >>> vae = GIMVI.load(save_path, adata_seq=adata_seq, adata_spatial=adata_spatial)
    >>> vae.get_latent_representation()
    """
    _, device = parse_use_gpu_arg(use_gpu)
    (
        attr_dict,
        seq_var_names,
        spatial_var_names,
        model_state_dict,
        loaded_adata_seq,
        loaded_adata_spatial,
    ) = _load_saved_gimvi_files(
        dir_path,
        adata_seq is None,
        adata_spatial is None,
        prefix=prefix,
        map_location=device,
    )
    adata_seq = loaded_adata_seq or adata_seq
    adata_spatial = loaded_adata_spatial or adata_spatial
    adatas = [adata_seq, adata_spatial]
    var_names = [seq_var_names, spatial_var_names]

    for i, adata in enumerate(adatas):
        saved_var_names = var_names[i]
        user_var_names = adata.var_names.astype(str)
        if not np.array_equal(saved_var_names, user_var_names):
            warnings.warn(
                "var_names for adata passed in does not match var_names of "
                "adata used to train the model. For valid results, the vars "
                "need to be the same and in the same order as the adata used to train the model."
            )

    if "scvi_setup_dicts_" in attr_dict:
        scvi_setup_dicts = attr_dict.pop("scvi_setup_dicts_")
        for adata, scvi_setup_dict in zip(adatas, scvi_setup_dicts):
            cls.register_manager(
                manager_from_setup_dict(cls, adata, scvi_setup_dict)
            )
    else:
        registries = attr_dict.pop("registries_")
        for adata, registry in zip(adatas, registries):
            if (
                _MODEL_NAME_KEY in registry
                and registry[_MODEL_NAME_KEY] != cls.__name__
            ):
                raise ValueError(
                    "It appears you are loading a model from a different class."
                )
            if _SETUP_KWARGS_KEY not in registry:
                raise ValueError(
                    "Saved model does not contain original setup inputs. "
                    "Cannot load the original setup."
                )
            cls.setup_anndata(
                adata, source_registry=registry, **registry[_SETUP_KWARGS_KEY]
            )

    # get the parameters for the class init signature
    init_params = attr_dict.pop("init_params_")

    # new saving and loading, enable backwards compatibility
    if "non_kwargs" in init_params.keys():
        # grab all the parameters except for kwargs (is a dict)
        non_kwargs = init_params["non_kwargs"]
        kwargs = init_params["kwargs"]

        # expand out kwargs
        kwargs = {k: v for (i, j) in kwargs.items() for (k, v) in j.items()}
    else:
        # grab all the parameters except for kwargs (is a dict)
        non_kwargs = {
            k: v for k, v in init_params.items() if not isinstance(v, dict)
        }
        kwargs = {k: v for k, v in init_params.items() if isinstance(v, dict)}
        kwargs = {k: v for (i, j) in kwargs.items() for (k, v) in j.items()}
    model = cls(adata_seq, adata_spatial, **non_kwargs, **kwargs)

    for attr, val in attr_dict.items():
        setattr(model, attr, val)

    model.module.load_state_dict(model_state_dict)
    model.module.eval()
    model.to_device(device)
    return model
def load(
    cls,
    dir_path: str,
    adata_seq: Optional[AnnData] = None,
    adata_spatial: Optional[AnnData] = None,
    use_gpu: Optional[Union[str, int, bool]] = None,
):
    """
    Instantiate a model from the saved output.

    Parameters
    ----------
    dir_path
        Path to saved outputs.
    adata_seq
        AnnData organized in the same way as data used to train model.
        It is not necessary to run :func:`~scvi.data.setup_anndata`,
        as AnnData is validated against the saved `scvi` setup dictionary.
        AnnData must be registered via :func:`~scvi.data.setup_anndata`.
    adata_spatial
        AnnData organized in the same way as data used to train model.
        If None, will check for and load anndata saved with the model.
    use_gpu
        Load model on default GPU if available (if None or True),
        or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False).

    Returns
    -------
    Model with loaded state dictionaries.

    Examples
    --------
    >>> vae = GIMVI.load(save_path, adata_seq=adata_seq, adata_spatial=adata_spatial)
    >>> vae.get_latent_representation()
    """
    model_path = os.path.join(dir_path, "model_params.pt")
    setup_dict_path = os.path.join(dir_path, "attr.pkl")
    seq_data_path = os.path.join(dir_path, "adata_seq.h5ad")
    spatial_data_path = os.path.join(dir_path, "adata_spatial.h5ad")
    seq_var_names_path = os.path.join(dir_path, "var_names_seq.csv")
    spatial_var_names_path = os.path.join(dir_path, "var_names_spatial.csv")

    if adata_seq is None and os.path.exists(seq_data_path):
        adata_seq = read(seq_data_path)
    elif adata_seq is None and not os.path.exists(seq_data_path):
        raise ValueError(
            "Save path contains no saved anndata and no adata was passed."
        )
    if adata_spatial is None and os.path.exists(spatial_data_path):
        adata_spatial = read(spatial_data_path)
    elif adata_spatial is None and not os.path.exists(spatial_data_path):
        raise ValueError(
            "Save path contains no saved anndata and no adata was passed."
        )
    adatas = [adata_seq, adata_spatial]

    seq_var_names = np.genfromtxt(seq_var_names_path, delimiter=",", dtype=str)
    spatial_var_names = np.genfromtxt(
        spatial_var_names_path, delimiter=",", dtype=str
    )
    var_names = [seq_var_names, spatial_var_names]

    for i, adata in enumerate(adatas):
        saved_var_names = var_names[i]
        user_var_names = adata.var_names.astype(str)
        if not np.array_equal(saved_var_names, user_var_names):
            warnings.warn(
                "var_names for adata passed in does not match var_names of "
                "adata used to train the model. For valid results, the vars "
                "need to be the same and in the same order as the adata used to train the model."
            )

    with open(setup_dict_path, "rb") as handle:
        attr_dict = pickle.load(handle)

    scvi_setup_dicts = attr_dict.pop("scvi_setup_dicts_")
    transfer_anndata_setup(scvi_setup_dicts["seq"], adata_seq)
    transfer_anndata_setup(scvi_setup_dicts["spatial"], adata_spatial)

    # get the parameters for the class init signature
    init_params = attr_dict.pop("init_params_")

    # new saving and loading, enable backwards compatibility
    if "non_kwargs" in init_params.keys():
        # grab all the parameters except for kwargs (is a dict)
        non_kwargs = init_params["non_kwargs"]
        kwargs = init_params["kwargs"]

        # expand out kwargs
        kwargs = {k: v for (i, j) in kwargs.items() for (k, v) in j.items()}
    else:
        # grab all the parameters except for kwargs (is a dict)
        non_kwargs = {
            k: v for k, v in init_params.items() if not isinstance(v, dict)
        }
        kwargs = {k: v for k, v in init_params.items() if isinstance(v, dict)}
        kwargs = {k: v for (i, j) in kwargs.items() for (k, v) in j.items()}
    model = cls(adata_seq, adata_spatial, **non_kwargs, **kwargs)

    for attr, val in attr_dict.items():
        setattr(model, attr, val)

    _, device = parse_use_gpu_arg(use_gpu)
    model.module.load_state_dict(torch.load(model_path, map_location=device))
    model.module.eval()
    model.to_device(device)
    return model
def _posterior_samples_minibatch(
    self, use_gpu: Optional[bool] = None, batch_size: Optional[int] = None, **sample_kwargs
):
    """
    Generate samples of the posterior distribution in minibatches.

    Generate samples of the posterior distribution of each parameter,
    separating local (minibatch) variables and global variables, which is
    necessary when performing minibatch inference.

    Parameters
    ----------
    use_gpu
        Load model on default GPU if available (if None or True),
        or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False).
    batch_size
        Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.

    Returns
    -------
    dictionary {variable_name: [array with samples in 0 dimension]}
    """
    samples = dict()

    _, device = parse_use_gpu_arg(use_gpu)

    batch_size = batch_size if batch_size is not None else settings.batch_size

    train_dl = AnnDataLoader(self.adata, shuffle=False, batch_size=batch_size)

    # sample local parameters
    i = 0
    for tensor_dict in track(
        train_dl,
        style="tqdm",
        description="Sampling local variables, batch: ",
    ):
        args, kwargs = self.module._get_fn_args_from_batch(tensor_dict)
        args = [a.to(device) for a in args]
        kwargs = {k: v.to(device) for k, v in kwargs.items()}
        self.to_device(device)

        if i == 0:
            # sample_kwargs is a dict, so look keys up with .get() (not getattr)
            return_observed = sample_kwargs.get("return_observed", False)
            obs_plate_sites = self._get_obs_plate_sites(
                args, kwargs, return_observed=return_observed
            )
            if len(obs_plate_sites) == 0:
                # if there are no local variables - don't sample
                break
            obs_plate_dim = list(obs_plate_sites.values())[0]

            sample_kwargs_obs_plate = sample_kwargs.copy()
            sample_kwargs_obs_plate[
                "return_sites"
            ] = self._get_obs_plate_return_sites(
                sample_kwargs["return_sites"], list(obs_plate_sites.keys())
            )
            sample_kwargs_obs_plate["show_progress"] = False

            samples = self._get_posterior_samples(
                args, kwargs, **sample_kwargs_obs_plate
            )
        else:
            samples_ = self._get_posterior_samples(
                args, kwargs, **sample_kwargs_obs_plate
            )

            samples = {
                k: np.array(
                    [
                        np.concatenate(
                            [samples[k][j], samples_[k][j]],
                            axis=obs_plate_dim,
                        )
                        for j in range(len(samples[k]))  # for each sample (in 0 dimension)
                    ]
                )
                for k in samples.keys()  # for each variable
            }
        i += 1

    # sample global parameters
    global_samples = self._get_posterior_samples(args, kwargs, **sample_kwargs)
    global_samples = {
        k: v
        for k, v in global_samples.items()
        if k not in list(obs_plate_sites.keys())
    }

    for k in global_samples.keys():
        samples[k] = global_samples[k]

    self.module.to(device)

    return samples
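# Usage sketch (hedged): extra keyword arguments are forwarded to
# `_get_posterior_samples`; `num_samples` and the site name are assumptions.
# >>> samples = mod._posterior_samples_minibatch(batch_size=1024, num_samples=30)
# >>> samples["w_sf"].shape  # (num_samples, ...) with samples in dimension 0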
def load(
    cls,
    dir_path: str,
    prefix: Optional[str] = None,
    adata: Optional[AnnData] = None,
    use_gpu: Optional[Union[str, int, bool]] = None,
):
    """
    Instantiate a model from the saved output.

    Parameters
    ----------
    dir_path
        Path to saved outputs.
    prefix
        Prefix of saved file names.
    adata
        AnnData organized in the same way as data used to train model.
        It is not necessary to run setup_anndata,
        as AnnData is validated against the saved `scvi` setup dictionary.
        If None, will check for and load anndata saved with the model.
    use_gpu
        Load model on default GPU if available (if None or True),
        or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False).

    Returns
    -------
    Model with loaded state dictionaries.

    Examples
    --------
    >>> model = ModelClass.load(save_path, adata)  # use the name of the model class used to save
    >>> model.get_....
    """
    load_adata = adata is None
    use_gpu, device = parse_use_gpu_arg(use_gpu)

    (
        attr_dict,
        var_names,
        model_state_dict,
        new_adata,
    ) = _load_saved_files(dir_path, load_adata, map_location=device, prefix=prefix)
    adata = new_adata if new_adata is not None else adata

    _validate_var_names(adata, var_names)

    # Legacy support for old setup dict format.
    if "scvi_setup_dict_" in attr_dict:
        scvi_setup_dict = attr_dict.pop("scvi_setup_dict_")
        cls.register_manager(manager_from_setup_dict(cls, adata, scvi_setup_dict))
    else:
        registry = attr_dict.pop("registry_")
        if _MODEL_NAME_KEY in registry and registry[_MODEL_NAME_KEY] != cls.__name__:
            raise ValueError(
                "It appears you are loading a model from a different class."
            )

        if _SETUP_KWARGS_KEY not in registry:
            raise ValueError(
                "Saved model does not contain original setup inputs. "
                "Cannot load the original setup."
            )

        # Calling ``setup_anndata`` method with the original arguments passed into
        # the saved model. This enables simple backwards compatibility in the case of
        # newly introduced fields or parameters.
        cls.setup_anndata(
            adata, source_registry=registry, **registry[_SETUP_KWARGS_KEY]
        )

    model = _initialize_model(cls, adata, attr_dict)

    # some Pyro modules with AutoGuides may need one training step
    try:
        model.module.load_state_dict(model_state_dict)
    except RuntimeError as err:
        if isinstance(model.module, PyroBaseModuleClass):
            old_history = model.history_.copy()
            logger.info("Preparing underlying module for load")
            model.train(max_steps=1)
            model.history_ = old_history
            pyro.clear_param_store()
            model.module.load_state_dict(model_state_dict)
        else:
            raise err

    model.to_device(device)
    model.module.eval()
    model._validate_anndata(adata)
    return model
def setup(self, stage: Optional[str] = None):
    """Split indices in train/test/val sets."""
    n_labeled_idx = len(self._labeled_indices)
    n_unlabeled_idx = len(self._unlabeled_indices)

    if n_labeled_idx != 0:
        n_labeled_train, n_labeled_val = validate_data_split(
            n_labeled_idx, self.train_size, self.validation_size
        )
        rs = np.random.RandomState(seed=settings.seed)
        labeled_permutation = rs.choice(
            self._labeled_indices, len(self._labeled_indices), replace=False
        )
        labeled_idx_val = labeled_permutation[:n_labeled_val]
        labeled_idx_train = labeled_permutation[
            n_labeled_val : (n_labeled_val + n_labeled_train)
        ]
        labeled_idx_test = labeled_permutation[(n_labeled_val + n_labeled_train) :]
    else:
        labeled_idx_test = []
        labeled_idx_train = []
        labeled_idx_val = []

    if n_unlabeled_idx != 0:
        n_unlabeled_train, n_unlabeled_val = validate_data_split(
            n_unlabeled_idx, self.train_size, self.validation_size
        )
        rs = np.random.RandomState(seed=settings.seed)
        # replace=False makes this a true permutation without repeated indices
        unlabeled_permutation = rs.choice(
            self._unlabeled_indices, len(self._unlabeled_indices), replace=False
        )
        unlabeled_idx_val = unlabeled_permutation[:n_unlabeled_val]
        unlabeled_idx_train = unlabeled_permutation[
            n_unlabeled_val : (n_unlabeled_val + n_unlabeled_train)
        ]
        unlabeled_idx_test = unlabeled_permutation[
            (n_unlabeled_val + n_unlabeled_train) :
        ]
    else:
        unlabeled_idx_train = []
        unlabeled_idx_val = []
        unlabeled_idx_test = []

    indices_train = np.concatenate((labeled_idx_train, unlabeled_idx_train))
    indices_val = np.concatenate((labeled_idx_val, unlabeled_idx_val))
    indices_test = np.concatenate((labeled_idx_test, unlabeled_idx_test))

    self.train_idx = indices_train.astype(int)
    self.val_idx = indices_val.astype(int)
    self.test_idx = indices_test.astype(int)

    gpus = parse_use_gpu_arg(self.use_gpu, return_device=False)
    self.pin_memory = (
        True if (settings.dl_pin_memory_gpu_training and gpus != 0) else False
    )

    if len(self._labeled_indices) != 0:
        self.data_loader_class = SemiSupervisedDataLoader
        dl_kwargs = {
            "unlabeled_category": self.unlabeled_category,
            "n_samples_per_label": self.n_samples_per_label,
        }
    else:
        self.data_loader_class = AnnDataLoader
        dl_kwargs = {}

    self.data_loader_kwargs.update(dl_kwargs)
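# Worked example (hedged, assuming validate_data_split returns counts
# proportional to train_size and validation_size): with 100 labeled and 900
# unlabeled cells, train_size=0.9 and validation_size=0.05, each pool is split
# independently: 90 + 810 -> train, 5 + 45 -> val, and the remaining 5 + 45 -> test.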
def _posterior_quantile_amortised(
    self, q: float = 0.5, batch_size: int = 2048, use_gpu: bool = True
):
    """
    Compute quantile of the posterior distribution of each parameter,
    separating local (minibatch) variables and global variables, which is
    necessary when performing amortised inference.

    Note for developers: requires model class method which lists
    observation/minibatch plate variables (self.module.model.list_obs_plate_vars()).

    Parameters
    ----------
    q
        Quantile to compute
    batch_size
        Number of observations per batch
    use_gpu
        Bool, use gpu?

    Returns
    -------
    dictionary {variable_name: posterior quantile}

    """
    gpus, device = parse_use_gpu_arg(use_gpu)

    self.module.eval()

    train_dl = AnnDataLoader(self.adata, shuffle=False, batch_size=batch_size)

    # sample local parameters
    i = 0
    for tensor_dict in train_dl:
        args, kwargs = self.module._get_fn_args_from_batch(tensor_dict)
        args = [a.to(device) for a in args]
        kwargs = {k: v.to(device) for k, v in kwargs.items()}
        self.to_device(device)

        if i == 0:
            means = self.module.guide.quantiles([q], *args, **kwargs)
            means = {
                k: means[k].cpu().numpy()
                for k in means.keys()
                if k in self.module.model.list_obs_plate_vars()["sites"]
            }

            # find plate dimension
            trace = poutine.trace(self.module.model).get_trace(*args, **kwargs)
            obs_plate = {
                name: site["cond_indep_stack"][0].dim
                for name, site in trace.nodes.items()
                if site["type"] == "sample"
                if any(
                    f.name == self.module.model.list_obs_plate_vars()["name"]
                    for f in site["cond_indep_stack"]
                )
            }
        else:
            means_ = self.module.guide.quantiles([q], *args, **kwargs)
            means_ = {
                k: means_[k].cpu().numpy()
                for k in means_.keys()
                if k in list(self.module.model.list_obs_plate_vars()["sites"].keys())
            }
            means = {
                k: np.concatenate(
                    [means[k], means_[k]], axis=list(obs_plate.values())[0]
                )
                for k in means.keys()
            }
        i += 1

    # sample global parameters
    tensor_dict = next(iter(train_dl))
    args, kwargs = self.module._get_fn_args_from_batch(tensor_dict)
    args = [a.to(device) for a in args]
    kwargs = {k: v.to(device) for k, v in kwargs.items()}
    self.to_device(device)

    global_means = self.module.guide.quantiles([q], *args, **kwargs)
    global_means = {
        k: global_means[k].cpu().numpy()
        for k in global_means.keys()
        if k not in list(self.module.model.list_obs_plate_vars()["sites"].keys())
    }

    for k in global_means.keys():
        means[k] = global_means[k]

    self.module.to(device)

    return means
def train(
    self,
    max_epochs: int = 200,
    use_gpu: Optional[Union[str, int, bool]] = None,
    kappa: int = 5,
    train_size: float = 0.9,
    validation_size: Optional[float] = None,
    batch_size: int = 128,
    plan_kwargs: Optional[dict] = None,
    **kwargs,
):
    """
    Train the model.

    Parameters
    ----------
    max_epochs
        Number of passes through the dataset.
    use_gpu
        Use default GPU if available (if None or True), or index of GPU to use (if int),
        or name of GPU (if str, e.g., `'cuda:0'`), or use CPU (if False).
    kappa
        Scaling parameter for the discriminator loss.
    train_size
        Size of training set in the range [0.0, 1.0].
    validation_size
        Size of the validation set. If `None`, defaults to 1 - `train_size`. If
        `train_size + validation_size < 1`, the remaining cells belong to a test set.
    batch_size
        Minibatch size to use during training.
    plan_kwargs
        Keyword args for model-specific Pytorch Lightning task. Keyword arguments passed to
        `train()` will overwrite values present in `plan_kwargs`, when appropriate.
    **kwargs
        Other keyword args for :class:`~scvi.train.Trainer`.
    """
    gpus, device = parse_use_gpu_arg(use_gpu)

    self.trainer = Trainer(
        max_epochs=max_epochs,
        gpus=gpus,
        **kwargs,
    )
    self.train_indices_, self.test_indices_, self.validation_indices_ = [], [], []
    train_dls, test_dls, val_dls = [], [], []
    for i, adm in enumerate(self.adata_managers.values()):
        ds = DataSplitter(
            adm,
            train_size=train_size,
            validation_size=validation_size,
            batch_size=batch_size,
            use_gpu=use_gpu,
        )
        ds.setup()
        train_dls.append(ds.train_dataloader())
        test_dls.append(ds.test_dataloader())
        val = ds.val_dataloader()
        val_dls.append(val)
        val.mode = i
        self.train_indices_.append(ds.train_idx)
        self.test_indices_.append(ds.test_idx)
        self.validation_indices_.append(ds.val_idx)
    train_dl = TrainDL(train_dls)

    plan_kwargs = plan_kwargs if isinstance(plan_kwargs, dict) else dict()
    self._training_plan = GIMVITrainingPlan(
        self.module,
        adversarial_classifier=True,
        scale_adversarial_loss=kappa,
        **plan_kwargs,
    )

    if train_size == 1.0:
        # circumvent the empty data loader problem if all data is used for training
        self.trainer.fit(self._training_plan, train_dl)
    else:
        # accepts list of val dataloaders
        self.trainer.fit(self._training_plan, train_dl, val_dls)
    try:
        self.history_ = self.trainer.logger.history
    except AttributeError:
        self.history_ = None
    self.module.eval()

    self.to_device(device)
    self.is_trained_ = True
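# Usage sketch (hedged): GIMVI trains jointly on a sequencing and a spatial
# dataset; `kappa` scales the adversarial discriminator loss.
# >>> model = GIMVI(adata_seq, adata_spatial)
# >>> model.train(max_epochs=200, kappa=5, train_size=0.9)
# >>> model.history_  # per-epoch metrics if the trainer's logger recorded them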
def load(
    cls,
    dir_path: str,
    prefix: Optional[str] = None,
    adata: Optional[AnnData] = None,
    use_gpu: Optional[Union[str, int, bool]] = None,
):
    """
    Instantiate a model from the saved output.

    Parameters
    ----------
    dir_path
        Path to saved outputs.
    prefix
        Prefix of saved file names.
    adata
        AnnData organized in the same way as data used to train model.
        It is not necessary to run setup_anndata,
        as AnnData is validated against the saved `scvi` setup dictionary.
        If None, will check for and load anndata saved with the model.
    use_gpu
        Load model on default GPU if available (if None or True),
        or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False).

    Returns
    -------
    Model with loaded state dictionaries.

    Examples
    --------
    >>> model = ModelClass.load(save_path, adata)  # use the name of the model class used to save
    >>> model.get_....
    """
    load_adata = adata is None
    use_gpu, device = parse_use_gpu_arg(use_gpu)

    (
        attr_dict,
        var_names,
        model_state_dict,
        new_adata,
    ) = _load_saved_files(dir_path, load_adata, map_location=device, prefix=prefix)
    adata = new_adata if new_adata is not None else adata

    scvi_setup_dict = attr_dict.pop("scvi_setup_dict_")

    # Filter out keys that are no longer populated by setup_anndata.
    # TODO(jhong): remove hack with setup_anndata refactor.
    deprecated_keys = {"local_l_mean", "local_l_var"}
    scvi_setup_dict["data_registry"] = {
        k: v
        for k, v in scvi_setup_dict["data_registry"].items()
        if k not in deprecated_keys
    }

    _validate_var_names(adata, var_names)
    transfer_anndata_setup(scvi_setup_dict, adata)
    model = _initialize_model(cls, adata, attr_dict)

    # set saved attrs for loaded model
    for attr, val in attr_dict.items():
        setattr(model, attr, val)

    # some Pyro modules with AutoGuides may need one training step
    try:
        model.module.load_state_dict(model_state_dict)
    except RuntimeError as err:
        if isinstance(model.module, PyroBaseModuleClass):
            old_history = model.history_.copy()
            logger.info("Preparing underlying module for load")
            model.train(max_steps=1)
            model.history_ = old_history
            pyro.clear_param_store()
            model.module.load_state_dict(model_state_dict)
        else:
            raise err

    model.to_device(device)
    model.module.eval()
    model._validate_anndata(adata)
    return model