def compute_expected_per_cell_type(self, samples, adata, ind_x=None):
    r"""Compute expected expression of each gene in each location for each cell type.

    Parameters
    ----------
    samples
        Posterior distribution summary ``self.samples[f"post_sample_{summary_name}"]``
        (``summary_name`` is one of 'means', 'stds', 'q05', 'q95') produced by
        ``export_posterior()``. Reads ``"w_sf"``, ``"m_g"`` and ``"s_g_gene_add"``.
    adata
        Registered anndata object (``self.adata``).
    ind_x
        Location/observation indices for which to compute expected count
        (if None all locations are used). Any integer array-like is accepted.

    Returns
    -------
    dictionary with:
        1) ``"mu"``: list with expected expression counts (sparse ``csr_matrix``,
           shape=(N locations, N genes)) for each cell type in the same order as
           ``mod.factor_names_``;
        2) ``"ind_x"``: np.array with location indices
    """
    if ind_x is None:
        ind_x = np.arange(adata.n_obs).astype(int)
    else:
        # np.asarray lets plain lists / ranges be passed as well as numpy arrays
        ind_x = np.asarray(ind_x).astype(int)

    # fetch observed counts for the selected locations
    x_data = csr_matrix(get_from_registry(adata, _CONSTANTS.X_KEY)[ind_x, :])

    # compute total expected expression: all cell types plus additive per-batch background
    obs2sample = get_from_registry(adata, _CONSTANTS.BATCH_KEY)
    obs2sample = pd.get_dummies(obs2sample.flatten()).values[ind_x, :]
    w_sf = samples["w_sf"][ind_x, :]
    cell_state_t = self.cell_state_mat.T
    mu = np.dot(w_sf, cell_state_t) * samples["m_g"] + np.dot(obs2sample, samples["s_g_gene_add"])

    # Split the observed counts across cell types in proportion to each cell type's
    # share of the total expected expression `mu` (outer product of abundance and signature).
    mu_ct = [
        csr_matrix(
            x_data.multiply(
                (np.dot(w_sf[:, i, np.newaxis], cell_state_t[np.newaxis, i, :]) * samples["m_g"])
                / mu
            )
        )
        for i in range(self.n_factors)
    ]
    return {"mu": mu_ct, "ind_x": ind_x}
def plot_QC(self, summary_name: str = "means", use_n_obs: int = 1000):
    """Show quality control plots.

    1. Reconstruction accuracy to assess if there are any issues with model training.
       The plot should be roughly diagonal, strong deviations signal problems that need
       to be investigated. Plotting is slow because expected value of mRNA count needs
       to be computed from model parameters. Random observations are used to speed up
       computation.

    Parameters
    ----------
    summary_name
        posterior distribution summary to use ('means', 'stds', 'q05', 'q95')
    use_n_obs
        number of randomly chosen observations to plot (capped at ``adata.n_obs``);
        if None, all observations are used.

    Raises
    ------
    RuntimeError
        If ``self.samples`` is missing (``export_posterior()`` has not been run).
    """
    if getattr(self, "samples", False) is False:
        raise RuntimeError("self.samples is missing, please run self.export_posterior() first")
    if use_n_obs is not None:
        ind_x = np.random.choice(self.adata.n_obs, np.min((use_n_obs, self.adata.n_obs)), replace=False)
    else:
        # Select all rows explicitly: indexing a matrix with `None` would insert a new
        # axis (dense) or fail (sparse) instead of selecting every observation.
        ind_x = np.arange(self.adata.n_obs)

    self.expected_nb_param = self.module.model.compute_expected(
        self.samples[f"post_sample_{summary_name}"], self.adata, ind_x=ind_x
    )
    x_data = get_from_registry(self.adata, _CONSTANTS.X_KEY)[ind_x, :]
    if issparse(x_data):
        x_data = np.asarray(x_data.toarray())
    self.plot_posterior_mu_vs_data(self.expected_nb_param["mu"], x_data)
def compute_expected_subset(self, samples, adata, fact_ind, cell_ind):
    r"""Compute expected expression of each gene in each cell that comes from
    a subset of factors (cell types) or cells.

    Useful for evaluating how well the model learned expression pattern of all genes in the data.

    Parameters
    ----------
    samples
        dictionary with values of the posterior
    adata
        registered anndata
    fact_ind
        indices of factors/cell types to use (1-D integer or boolean array-like)
    cell_ind
        indices of cells to use (1-D integer or boolean array-like)

    Returns
    -------
    dict with ``"mu"`` (expected expression, rows = selected cells) and ``"alpha"``
    (``1 / alpha_g_inverse**2``).
    """
    cell_ind = np.asarray(cell_ind)
    fact_ind = np.asarray(fact_ind)

    # one-hot encodings are built on ALL cells (so every category keeps its column)
    # and converted to numpy before row selection
    obs2sample = get_from_registry(adata, _CONSTANTS.BATCH_KEY)
    obs2sample = pd.get_dummies(obs2sample.flatten()).values[cell_ind, :]
    obs2label = get_from_registry(adata, _CONSTANTS.LABELS_KEY)
    obs2label = pd.get_dummies(obs2label.flatten()).values
    if self.n_extra_categoricals is not None:
        extra_categoricals = get_from_registry(adata, _CONSTANTS.CAT_COVS_KEY)
        obs2extra_categoricals = np.concatenate(
            [
                pd.get_dummies(extra_categoricals.iloc[:, i])
                for i, _ in enumerate(self.n_extra_categoricals)
            ],
            axis=1,
        )[cell_ind, :]

    alpha = 1 / np.power(samples["alpha_g_inverse"], 2)

    # np.ix_ selects the (cells x factors) sub-matrix; plain `a[rows, cols]` would
    # pair the two index arrays element-wise instead.
    mu = (
        np.dot(obs2label[np.ix_(cell_ind, fact_ind)], samples["per_cluster_mu_fg"][fact_ind, :])
        + np.dot(obs2sample, samples["s_g_gene_add"])
    ) * np.dot(obs2sample, samples["detection_mean_y_e"])  # samples["detection_y_c"]
    if self.n_extra_categoricals is not None:
        mu = mu * np.dot(obs2extra_categoricals, samples["detection_tech_gene_tg"])

    return {"mu": mu, "alpha": alpha}
def setup_data_attr(self):
    """Cache every registered tensor of ``self.adata`` in ``self.data``.

    The cache is keyed by the registry keys of ``self.attributes_and_types``, so later
    code can read from ``self.data`` instead of querying the AnnData repeatedly.
    """
    cache = {}
    for registry_key in self.attributes_and_types:
        cache[registry_key] = get_from_registry(self.adata, registry_key)
    self.data = cache
def compute_expected(self, samples, adata, ind_x=None):
    r"""Compute expected expression of each gene in each cell.

    Useful for evaluating how well the model learned expression pattern of all genes in the data.

    Parameters
    ----------
    samples
        dictionary with values of the posterior; reads ``"alpha_g_inverse"``,
        ``"per_cluster_mu_fg"``, ``"s_g_gene_add"``, ``"detection_mean_y_e"`` and, when
        extra categoricals are registered, ``"detection_tech_gene_tg"``
    adata
        registered anndata
    ind_x
        indices of cells to use (to reduce data size); all cells if None.
        Any integer array-like is accepted.

    Returns
    -------
    dict with ``"mu"`` (expected expression, rows = selected cells) and ``"alpha"``
    (``1 / alpha_g_inverse**2``).
    """
    if ind_x is None:
        ind_x = np.arange(adata.n_obs).astype(int)
    else:
        ind_x = np.asarray(ind_x).astype(int)

    obs2sample = get_from_registry(adata, _CONSTANTS.BATCH_KEY)
    obs2sample = pd.get_dummies(obs2sample.flatten()).values[ind_x, :]
    obs2label = get_from_registry(adata, _CONSTANTS.LABELS_KEY)
    obs2label = pd.get_dummies(obs2label.flatten()).values[ind_x, :]
    if self.n_extra_categoricals is not None:
        extra_categoricals = get_from_registry(adata, _CONSTANTS.CAT_COVS_KEY)
        # Build dummies on ALL cells before subsetting: encoding only the subset would
        # drop columns for categories absent from `ind_x` and misalign them with the
        # rows of `detection_tech_gene_tg`.
        obs2extra_categoricals = np.concatenate(
            [
                pd.get_dummies(extra_categoricals.iloc[:, i])
                for i, _ in enumerate(self.n_extra_categoricals)
            ],
            axis=1,
        )[ind_x, :]

    alpha = 1 / np.power(samples["alpha_g_inverse"], 2)

    mu = (
        np.dot(obs2label, samples["per_cluster_mu_fg"]) + np.dot(obs2sample, samples["s_g_gene_add"])
    ) * np.dot(obs2sample, samples["detection_mean_y_e"])  # samples["detection_y_c"][ind_x, :]
    if self.n_extra_categoricals is not None:
        mu = mu * np.dot(obs2extra_categoricals, samples["detection_tech_gene_tg"])

    return {"mu": mu, "alpha": alpha}
def test_data_format():
    """setup_anndata must make dense X / protein data C-contiguous without changing values."""

    def _assert_registry_matches(ad):
        # the registered X and protein tensors must equal what is stored on the AnnData
        assert np.array_equal(ad.X, get_from_registry(ad, _CONSTANTS.X_KEY))
        assert np.array_equal(
            ad.obsm["protein_expression"],
            get_from_registry(ad, _CONSTANTS.PROTEIN_EXP_KEY),
        )

    # --- dense numpy arrays in Fortran order ---
    adata = synthetic_iid(run_setup_anndata=False)
    orig_x = adata.X
    orig_protein = adata.obsm["protein_expression"]
    orig_obs = adata.obs

    adata.X = np.asfortranarray(orig_x)
    adata.obsm["protein_expression"] = np.asfortranarray(orig_protein)
    assert adata.X.flags["C_CONTIGUOUS"] is False
    assert adata.obsm["protein_expression"].flags["C_CONTIGUOUS"] is False

    setup_anndata(adata, protein_expression_obsm_key="protein_expression")
    assert adata.X.flags["C_CONTIGUOUS"] is True
    assert adata.obsm["protein_expression"].flags["C_CONTIGUOUS"] is True

    assert np.array_equal(orig_x, adata.X)
    assert np.array_equal(orig_protein, adata.obsm["protein_expression"])
    assert np.array_equal(orig_obs, adata.obs)
    _assert_registry_matches(adata)

    # --- obsm stored as a DataFrame backed by a Fortran-ordered array ---
    adata = synthetic_iid()
    fortran_pe = np.asfortranarray(adata.obsm["protein_expression"])
    adata.obsm["protein_expression"] = pd.DataFrame(fortran_pe, index=adata.obs_names)
    assert adata.obsm["protein_expression"].to_numpy().flags["C_CONTIGUOUS"] is False

    setup_anndata(adata, protein_expression_obsm_key="protein_expression")
    registered_pe = get_from_registry(adata, "protein_expression")
    assert registered_pe.to_numpy().flags["C_CONTIGUOUS"] is True
    assert np.array_equal(fortran_pe, registered_pe)
    _assert_registry_matches(adata)
def normalise(self, samples, adata):
    r"""Normalise expression data by estimated technical variables.

    The registered counts are divided by the per-sample detection scaling
    (``samples["detection_mean_y_e"]``) and, when extra categorical covariates are
    registered, by the per-gene technical effect (``samples["detection_tech_gene_tg"]``).
    The additive per-sample background (``samples["s_g_gene_add"]``) is then subtracted.

    Parameters
    ----------
    samples
        dictionary with values of the posterior
    adata
        registered anndata

    Returns
    -------
    Dense ``np.ndarray`` of corrected expression, same shape as the registered counts.
    """
    from scipy.sparse import issparse

    obs2sample = get_from_registry(adata, _CONSTANTS.BATCH_KEY)
    obs2sample = pd.get_dummies(obs2sample.flatten())
    if self.n_extra_categoricals is not None:
        extra_categoricals = get_from_registry(adata, _CONSTANTS.CAT_COVS_KEY)
        obs2extra_categoricals = np.concatenate(
            [
                pd.get_dummies(extra_categoricals.iloc[:, i])
                for i, _ in enumerate(self.n_extra_categoricals)
            ],
            axis=1,
        )

    # get counts matrix; densify sparse input so the result is always a plain ndarray
    # (mixing scipy.sparse with dense arrays below would otherwise produce np.matrix)
    corrected = get_from_registry(adata, _CONSTANTS.X_KEY)
    if issparse(corrected):
        corrected = corrected.toarray()

    # normalise per-sample scaling
    corrected = corrected / np.dot(obs2sample, samples["detection_mean_y_e"])
    # normalise per gene effects
    if self.n_extra_categoricals is not None:
        corrected = corrected / np.dot(obs2extra_categoricals, samples["detection_tech_gene_tg"])

    # remove additive sample effects
    corrected = corrected - np.dot(obs2sample, samples["s_g_gene_add"])

    # Shift by the single global minimum so that no value is negative (a hack).
    # NOTE(review): this is a global shift, not a per-gene one; a per-gene shift
    # would be `corrected.min(axis=0)` — confirm which is intended.
    corrected = corrected - corrected.min()
    return corrected
def __init__(
    self,
    model,
    adata,
    n_labelled_samples_per_class=50,
    n_epochs_classifier=1,
    lr_classification=5 * 1e-3,
    classification_ratio=50,
    seed=0,
    **kwargs
):
    """Semi-supervised trainer that splits cells into labelled and unlabelled sets.

    Parameters
    ----------
    model
        model whose ``classifier`` is trained by an internal ``ClassifierTrainer``
    adata
        registered anndata
    n_labelled_samples_per_class
        number of cells per label (chosen at random) treated as labelled
    n_epochs_classifier, lr_classification, classification_ratio
        stored on the trainer for the classification step
    seed
        seed for the global numpy RNG used to shuffle cells before the split
    **kwargs
        forwarded to the parent trainer
    """
    super().__init__(model, adata, **kwargs)
    self.model = model
    self.adata = adata
    self.n_epochs_classifier = n_epochs_classifier
    self.lr_classification = lr_classification
    self.classification_ratio = classification_ratio

    n_labels = self.adata.uns["_scvi"]["summary_stats"]["n_labels"]
    n_labelled_samples_per_class_array = [n_labelled_samples_per_class] * n_labels
    labels = np.array(get_from_registry(self.adata, _CONSTANTS.LABELS_KEY)).ravel()
    np.random.seed(seed=seed)
    permutation_idx = np.random.permutation(len(labels))
    labels = labels[permutation_idx]

    # Walk the shuffled cells: the first `n_labelled_samples_per_class` of each label
    # become labelled, everything else unlabelled. Collecting into two lists (and
    # reversing the labelled one to keep the previous most-recent-first order) is O(n)
    # instead of O(n^2) repeated `list.insert(0, ...)`.
    labelled_pos = []
    unlabelled_pos = []
    current_nbrs = np.zeros(len(n_labelled_samples_per_class_array))
    for idx, label in enumerate(labels):
        label = int(label)
        if current_nbrs[label] < n_labelled_samples_per_class_array[label]:
            labelled_pos.append(idx)
            current_nbrs[label] += 1
        else:
            unlabelled_pos.append(idx)
    labelled_pos.reverse()

    # Split on the cells actually selected: a class with fewer cells than requested
    # must not cause unlabelled cells to spill into the labelled set.
    indices_labelled = permutation_idx[np.array(labelled_pos, dtype=int)]
    indices_unlabelled = permutation_idx[np.array(unlabelled_pos, dtype=int)]

    self.classifier_trainer = ClassifierTrainer(
        model.classifier,
        self.adata,
        metrics_to_monitor=[],
        silent=True,
        frequency=0,
        sampling_model=self.model,
    )
    self.full_dataset = self.create_scvi_dl(shuffle=True)
    self.labelled_set = self.create_scvi_dl(indices=indices_labelled)
    self.unlabelled_set = self.create_scvi_dl(indices=indices_unlabelled)

    for scdl in [self.labelled_set, self.unlabelled_set]:
        scdl.to_monitor = ["reconstruction_error", "accuracy"]
def compute_expected(self, samples, adata, ind_x=None):
    r"""Compute expected expression of each gene in each location.

    Useful for evaluating how well the model learned expression pattern of all genes in the data.

    Parameters
    ----------
    samples
        dictionary with values of the posterior; reads ``"w_sf"``, ``"s_g_gene_add"``,
        ``"detection_y_s"`` and ``"alpha_g_inverse"``
    adata
        registered anndata
    ind_x
        indices of locations to use (all locations if None); any integer array-like

    Returns
    -------
    dict with ``"mu"`` (expected expression for the selected locations), ``"alpha"``
    (per-location ``1 / alpha_g_inverse**2`` taken from each location's batch) and
    ``"ind_x"`` (the integer location indices used).
    """
    if ind_x is None:
        ind_x = np.arange(adata.n_obs).astype(int)
    else:
        ind_x = np.asarray(ind_x).astype(int)

    obs2sample = get_from_registry(adata, _CONSTANTS.BATCH_KEY)
    obs2sample = pd.get_dummies(obs2sample.flatten()).values[ind_x, :]

    # (cell abundance x signatures + additive batch background) scaled by detection efficiency
    mu = (
        np.dot(samples["w_sf"][ind_x, :], self.cell_state_mat.T)
        + np.dot(obs2sample, samples["s_g_gene_add"])
    ) * samples["detection_y_s"][ind_x, :]
    alpha = np.dot(obs2sample, 1 / np.power(samples["alpha_g_inverse"], 2))

    return {"mu": mu, "alpha": alpha, "ind_x": ind_x}
def create_doublets(
    adata: AnnData,
    doublet_ratio: float,
    indices: Optional[Sequence[int]] = None,
    seed: int = 1,
) -> AnnData:
    """Simulate doublets.

    Each doublet is the sum of the count vectors of two cells drawn uniformly at
    random (with replacement) from ``adata`` (or from ``indices``).

    Parameters
    ----------
    adata
        AnnData object setup with setup_anndata.
    doublet_ratio
        Ratio of generated doublets to produce relative to number of cells in adata
        or length of indices, if not `None`. May be fractional; the resulting count is
        rounded to the nearest integer.
    indices
        Indices of cells in adata to use. If `None`, all cells are used.
    seed
        Seed for reproducibility

    Returns
    -------
    AnnData with the simulated doublets (obs names ``sim_doublet_<i>``); if ``adata``
    was registered with a layer, the counts are also stored in that layer.
    """
    n_obs = adata.n_obs if indices is None else len(indices)
    # int(round(...)) lets fractional ratios work; integer ratios give the same count as before
    num_doublets = int(round(doublet_ratio * n_obs))

    # counts can be in many locations, this uses where it was registered in setup
    x = get_from_registry(adata, _CONSTANTS.X_KEY)
    if indices is not None:
        x = x[indices]

    random_state = np.random.RandomState(seed=seed)
    parent_inds = random_state.choice(n_obs, size=(num_doublets, 2))
    doublets = x[parent_inds[:, 0]] + x[parent_inds[:, 1]]

    doublets_ad = AnnData(doublets)
    doublets_ad.var_names = adata.var_names
    doublets_ad.obs_names = [f"sim_doublet_{i}" for i in range(num_doublets)]

    # if adata setup with a layer, need to add layer to doublets adata
    data_registry = adata.uns["_scvi"]["data_registry"]
    x_loc = data_registry[_CONSTANTS.X_KEY]["attr_name"]
    layer = data_registry[_CONSTANTS.X_KEY]["attr_key"] if x_loc == "layers" else None
    if layer is not None:
        doublets_ad.layers[layer] = doublets

    return doublets_ad
def test_setup_anndata():
    """setup_anndata registers batch, labels, X, layers and protein data correctly."""
    # --- regular setup with every key supplied ---
    adata = synthetic_iid(run_setup_anndata=False)
    setup_anndata(
        adata,
        batch_key="batch",
        labels_key="labels",
        protein_expression_obsm_key="protein_expression",
        protein_names_uns_key="protein_names",
    )
    expected_batch = np.array(adata.obs["_scvi_batch"]).reshape((-1, 1))
    np.testing.assert_array_equal(get_from_registry(adata, "batch_indices"), expected_batch)
    expected_labels = np.array(adata.obs["labels"].cat.codes).reshape((-1, 1))
    np.testing.assert_array_equal(get_from_registry(adata, "labels"), expected_labels)
    np.testing.assert_array_equal(get_from_registry(adata, "X"), adata.X)
    np.testing.assert_array_equal(
        get_from_registry(adata, "protein_expression"),
        adata.obsm["protein_expression"],
    )
    np.testing.assert_array_equal(adata.uns["_scvi"]["protein_names"], adata.uns["protein_names"])

    # --- an AnnData view must be rejected ---
    adata = synthetic_iid()
    with pytest.raises(ValueError):
        setup_anndata(adata[1])

    # --- DataFrame obsm without protein_names_uns_key: names come from the df columns ---
    adata = synthetic_iid()
    column_names = np.array(random.sample(range(100), 100)).astype("str")
    protein_df = pd.DataFrame(
        adata.obsm["protein_expression"],
        index=adata.obs_names,
        columns=column_names,
    )
    adata.obsm["protein_expression"] = protein_df
    setup_anndata(adata, protein_expression_obsm_key="protein_expression")
    np.testing.assert_array_equal(adata.uns["_scvi"]["protein_names"], column_names)

    # --- counts registered from a layer, not from X ---
    adata = synthetic_iid()
    layer_counts = adata.X
    adata.layers["X"] = layer_counts
    adata.X = np.ones_like(adata.X)
    setup_anndata(adata, layer="X")
    np.testing.assert_array_equal(get_from_registry(adata, "X"), layer_counts)

    # --- no batch / labels keys: all-zero defaults are created ---
    adata = synthetic_iid()
    setup_anndata(
        adata,
        protein_expression_obsm_key="protein_expression",
        protein_names_uns_key="protein_names",
    )
    zeros_column = np.zeros((adata.shape[0], 1))
    np.testing.assert_array_equal(get_from_registry(adata, "batch_indices"), zeros_column)
    np.testing.assert_array_equal(get_from_registry(adata, "labels"), zeros_column)
def custom_objective_hyperopt(
    space, is_best_training=False, dataset=None, n_epochs=None
):
    """Custom objective function for advanced autotune tutorial.

    Trains a SCANVAE with a SemiSupervisedTrainer. Batch 0 is used as the labelled
    set and batch 1 as the unlabelled set.

    When ``is_best_training`` is False, training is repeated over 5 random 80/20
    splits of the labelled cells. The result is a hyperopt dict whose ``"loss"`` is
    the negated mean of the final ``"accuracy_unlabelled_set"`` values.

    When ``is_best_training`` is True, training uses all labelled cells and the
    trainer is returned.
    """
    space = defaultdict(dict, space)
    model_tunable_kwargs = space["model_tunable_kwargs"]
    trainer_tunable_kwargs = space["trainer_tunable_kwargs"]
    train_func_tunable_kwargs = space["train_func_tunable_kwargs"]

    # hardcoded parameters; they override the tunable ones on merge
    trainer_specific_kwargs = {
        "use_cuda": bool(torch.cuda.device_count()),
        "silent": True,  # disable scVI progbar
        "frequency": 1,
    }
    model_specific_kwargs = {}
    train_func_specific_kwargs = {"n_epochs": n_epochs}

    # merge params with fixed param precedence
    model_tunable_kwargs.update(model_specific_kwargs)
    trainer_tunable_kwargs.update(trainer_specific_kwargs)
    train_func_tunable_kwargs.update(train_func_specific_kwargs)

    summary_stats = dataset.uns["_scvi"]["summary_stats"]
    scanvi = SCANVAE(
        summary_stats["n_vars"],
        summary_stats["n_batch"],
        summary_stats["n_labels"],
        **model_tunable_kwargs
    )
    trainer_scanvi = SemiSupervisedTrainer(scanvi, dataset, **trainer_tunable_kwargs)

    batch_indices = get_from_registry(dataset, _CONSTANTS.BATCH_KEY)
    trainer_scanvi.unlabelled_set = trainer_scanvi.create_scvi_dl(indices=(batch_indices == 1))
    trainer_scanvi.unlabelled_set.to_monitor = ["reconstruction_error", "accuracy"]
    indices_labelled = batch_indices == 0

    if is_best_training:
        trainer_scanvi.labelled_set = trainer_scanvi.create_scvi_dl(indices=indices_labelled)
        trainer_scanvi.labelled_set.to_monitor = ["reconstruction_error", "accuracy"]
        trainer_scanvi.train(**train_func_tunable_kwargs)
        return trainer_scanvi

    # compute k-fold accuracy on a 20% validation set
    # NOTE(review): the same trainer/model keeps training across folds, and the score
    # read is the unlabelled-set accuracy rather than the validation-set one — confirm intent.
    n_folds = 5
    accuracies = np.zeros(n_folds)
    labelled_rows = indices_labelled.nonzero()[0]
    for fold in range(n_folds):
        train_rows, val_rows = train_test_split(labelled_rows, test_size=0.2)
        trainer_scanvi.labelled_set = trainer_scanvi.create_scvi_dl(indices=train_rows)
        trainer_scanvi.labelled_set.to_monitor = [
            "reconstruction_error",
            "accuracy",
        ]
        trainer_scanvi.validation_set = trainer_scanvi.create_scvi_dl(indices=val_rows)
        trainer_scanvi.validation_set.to_monitor = ["accuracy"]
        trainer_scanvi.train(**train_func_tunable_kwargs)
        accuracies[fold] = trainer_scanvi.history["accuracy_unlabelled_set"][-1]
    return {"loss": -accuracies.mean(), "space": space, "status": STATUS_OK}
def __init__(
    self,
    model,
    adata,
    n_labelled_samples_per_class=50,
    indices_labelled=None,
    indices_unlabelled=None,
    n_epochs_classifier=1,
    lr_classification=5 * 1e-3,
    classification_ratio=50,
    seed=0,
    scheme: Literal["joint", "alternate", "both"] = "both",
    **kwargs,
):
    """Semi-supervised trainer with labelled / unlabelled data loaders.

    Parameters
    ----------
    model
        model whose ``classifier`` is trained by an internal ``ClassifierTrainer``
    adata
        registered anndata
    n_labelled_samples_per_class
        cells per label (chosen at random) treated as labelled; only used when
        neither ``indices_labelled`` nor ``indices_unlabelled`` is given
    indices_labelled, indices_unlabelled
        explicit cell indices for the labelled / unlabelled sets
    n_epochs_classifier, lr_classification, classification_ratio
        stored on the trainer for the classification step
    seed
        seed for the global numpy RNG used to shuffle cells before the random split
    scheme
        training scheme; ``"joint"`` sets ``n_epochs_classifier`` to 0
    **kwargs
        forwarded to the parent trainer; ``weight_decay`` is also passed to the classifier trainer
    """
    super().__init__(model, adata, **kwargs)
    self.model = model
    self.adata = adata
    self.n_epochs_classifier = n_epochs_classifier
    self.lr_classification = lr_classification
    self.classification_ratio = classification_ratio
    self.scheme = scheme
    if scheme == "joint":
        self.n_epochs_classifier = 0

    if indices_labelled is None and indices_unlabelled is None:
        n_labels = self.adata.uns["_scvi"]["summary_stats"]["n_labels"]
        n_labelled_samples_per_class_array = [n_labelled_samples_per_class] * n_labels
        labels = np.array(get_from_registry(self.adata, _CONSTANTS.LABELS_KEY)).ravel()
        np.random.seed(seed=seed)
        permutation_idx = np.random.permutation(len(labels))
        labels = labels[permutation_idx]

        # Walk the shuffled cells: the first `n_labelled_samples_per_class` of each label
        # become labelled, the rest unlabelled. Two lists (labelled one reversed to keep
        # the previous most-recent-first order) avoid O(n^2) `list.insert(0, ...)`.
        labelled_pos = []
        unlabelled_pos = []
        current_nbrs = np.zeros(len(n_labelled_samples_per_class_array))
        for idx, label in enumerate(labels):
            label = int(label)
            if current_nbrs[label] < n_labelled_samples_per_class_array[label]:
                labelled_pos.append(idx)
                current_nbrs[label] += 1
            else:
                unlabelled_pos.append(idx)
        labelled_pos.reverse()

        # Split on the cells actually selected: a class with fewer cells than requested
        # must not cause unlabelled cells to spill into the labelled set.
        indices_labelled = permutation_idx[np.array(labelled_pos, dtype=int)]
        indices_unlabelled = permutation_idx[np.array(unlabelled_pos, dtype=int)]

    class_kwargs = {}
    if "weight_decay" in kwargs:
        class_kwargs["weight_decay"] = kwargs["weight_decay"]
    self.classifier_trainer = ClassifierTrainer(
        model.classifier,
        self.adata,
        metrics_to_monitor=[],
        silent=True,
        frequency=0,
        sampling_model=self.model,
        **class_kwargs,
    )

    self.full_dataset = self.create_scvi_dl(shuffle=True)
    self.labelled_set = self.create_scvi_dl(indices=indices_labelled)
    self.unlabelled_set = self.create_scvi_dl(indices=indices_unlabelled)

    for scdl in [self.labelled_set, self.unlabelled_set]:
        scdl.to_monitor = ["elbo", "reconstruction_error", "accuracy"]

    # allow to track ELBO
    self.unlabelled_set.unlabeled = True
    self.full_dataset.unlabeled = True
def __init__(
    self,
    adata: AnnData,
    cell_state_df: pd.DataFrame,
    model_class: Optional[PyroModule] = None,
    detection_mean_per_sample: bool = False,
    detection_mean_correction: float = 1.0,
    **model_kwargs,
):
    """Set up the cell2location model for a registered spatial AnnData object.

    Parameters
    ----------
    adata
        Registered spatial AnnData; ``adata.var_names`` must equal ``cell_state_df.index``
        (same genes, same order). An ``"_indices"`` column is added to ``adata.obs``.
    cell_state_df
        Reference signatures with genes as index; its columns become ``factor_names_``.
    model_class
        Pyro model class; defaults to
        ``LocationModelLinearDependentWMultiExperimentLocationBackgroundNormLevelGeneAlphaPyroModel``.
    detection_mean_per_sample
        If True, ``detection_mean`` is estimated separately for each batch.
    detection_mean_correction
        Multiplicative correction applied to the estimated ``detection_mean``.
    **model_kwargs
        Passed to ``Cell2locationBaseModule``. ``N_cells_per_location`` (default 1) is
        read here to estimate ``detection_mean``. ``detection_alpha`` may be a dict keyed
        by batch category; it is converted to a per-batch float32 array.

    Raises
    ------
    ValueError
        If ``adata.var_names`` does not match ``cell_state_df.index``.
    """
    # in case any other model was created before that shares the same parameter names.
    clear_param_store()

    # Check lengths first: comparing pandas Indexes of different lengths raises its
    # own, less informative error instead of the message below.
    if len(adata.var_names) != len(cell_state_df.index) or not np.all(
        adata.var_names == cell_state_df.index
    ):
        raise ValueError(
            "adata.var_names should match cell_state_df.index, find intersecting variables/genes first"
        )

    # add index for each cell (provided to pyro plate for correct minibatching)
    adata.obs["_indices"] = np.arange(adata.n_obs).astype("int64")
    scvi.data.register_tensor_from_anndata(
        adata,
        registry_key="ind_x",
        adata_attr_name="obs",
        adata_key_name="_indices",
    )
    super().__init__(adata)

    if model_class is None:
        model_class = LocationModelLinearDependentWMultiExperimentLocationBackgroundNormLevelGeneAlphaPyroModel

    self.cell_state_df_ = cell_state_df
    self.n_factors_ = cell_state_df.shape[1]
    self.factor_names_ = cell_state_df.columns.values

    # compute expected change in sensitivity (m_g in V1 or y_s in V2):
    # (mean total counts per location / N_cells_per_location) / mean signature total
    sc_total = cell_state_df.sum(0).mean()
    n_cells_per_location = model_kwargs.get("N_cells_per_location", 1)
    x_data = get_from_registry(self.adata, _CONSTANTS.X_KEY)
    if not detection_mean_per_sample:
        sp_total = x_data.sum(1).mean()
        self.detection_mean_ = (sp_total / n_cells_per_location) / sc_total
        self.detection_mean_ = self.detection_mean_ * detection_mean_correction
        model_kwargs["detection_mean"] = self.detection_mean_
    else:
        # np.asarray(...).ravel() turns the (n_obs, 1) np.matrix returned for sparse
        # counts into a 1-D array so the boolean batch mask selects rows correctly.
        sp_total = np.asarray(x_data.sum(1)).ravel()
        batch = get_from_registry(self.adata, _CONSTANTS.BATCH_KEY).flatten()
        sp_total = np.array(
            [sp_total[batch == b].mean() for b in range(self.summary_stats["n_batch"])]
        )
        self.detection_mean_ = (sp_total / n_cells_per_location) / sc_total
        self.detection_mean_ = self.detection_mean_ * detection_mean_correction
        model_kwargs["detection_mean"] = self.detection_mean_.reshape(
            (self.summary_stats["n_batch"], 1)
        ).astype("float32")

    detection_alpha = model_kwargs.get("detection_alpha", None)
    if isinstance(detection_alpha, dict):
        # reorder the user-supplied per-batch values to match the registered batch codes
        batch_mapping = self.adata.uns["_scvi"]["categorical_mappings"]["_scvi_batch"]["mapping"]
        self.detection_alpha_ = pd.Series(detection_alpha)[batch_mapping]
        model_kwargs["detection_alpha"] = self.detection_alpha_.values.reshape(
            (self.summary_stats["n_batch"], 1)
        ).astype("float32")

    self.module = Cell2locationBaseModule(
        model=model_class,
        n_obs=self.summary_stats["n_cells"],
        n_vars=self.summary_stats["n_vars"],
        n_factors=self.n_factors_,
        n_batch=self.summary_stats["n_batch"],
        cell_state_mat=self.cell_state_df_.values.astype("float32"),
        **model_kwargs,
    )
    self._model_summary_string = f'cell2location model with the following params: \nn_factors: {self.n_factors_} \nn_batch: {self.summary_stats["n_batch"]} '
    self.init_params_ = self._get_init_params(locals())