def generative(
    self,
    z: torch.Tensor,
    library: torch.Tensor,
    batch_index: Optional[torch.Tensor] = None,
    y: Optional[torch.Tensor] = None,
    mode: Optional[int] = None,
) -> dict:
    px_scale, px_r, px_rate, px_dropout = self.decoder(
        z, mode, library, self.dispersion, batch_index, y
    )
    if self.dispersion == "gene-label":
        px_r = F.linear(one_hot(y, self.n_labels), self.px_r)
    elif self.dispersion == "gene-batch":
        px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
    elif self.dispersion == "gene":
        px_r = self.px_r.view(1, self.px_r.size(0))
    px_r = torch.exp(px_r)

    # renormalize the scale over the genes mapped to this mode (dataset)
    px_scale = px_scale / torch.sum(
        px_scale[:, self.indices_mappings[mode]], dim=1
    ).view(-1, 1)
    px_rate = px_scale * torch.exp(library)

    return dict(
        px_scale=px_scale, px_r=px_r, px_rate=px_rate, px_dropout=px_dropout
    )
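# --- Illustrative sketch, not part of the module above. It shows how the
# dispersion branches turn a per-(gene, batch) parameter into a per-cell
# matrix via F.linear on a one-hot encoding. The `one_hot` helper below is a
# plausible stand-in for the helper used in the code above, and all shapes
# are assumptions chosen for the example.
import torch
import torch.nn.functional as F


def one_hot(index: torch.Tensor, n_cat: int) -> torch.Tensor:
    # float one-hot encoding of an integer column vector: (B, 1) -> (B, n_cat)
    onehot = torch.zeros(index.size(0), n_cat, device=index.device)
    onehot.scatter_(1, index.type(torch.long), 1)
    return onehot.type(torch.float32)


n_genes, n_batch, B = 5, 3, 4
px_r = torch.randn(n_genes, n_batch)             # one dispersion per (gene, batch)
batch_index = torch.tensor([[0], [2], [1], [2]])

# F.linear(x, W) computes x @ W.T, so the one-hot rows select, for each cell,
# the column of px_r corresponding to its batch -> shape (B, n_genes).
per_cell_r = F.linear(one_hot(batch_index, n_batch), px_r)
assert per_cell_r.shape == (B, n_genes)
assert torch.allclose(per_cell_r[1], px_r[:, 2])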
def generative(self, z, library, batch_index, cont_covs=None, cat_covs=None, y=None):
    """Runs the generative model."""
    # TODO: refactor forward function to not rely on y
    decoder_input = z if cont_covs is None else torch.cat([z, cont_covs], dim=-1)
    if cat_covs is not None:
        categorical_input = torch.split(cat_covs, 1, dim=1)
    else:
        categorical_input = tuple()
    px_scale, px_r, px_rate, px_dropout = self.decoder(
        self.dispersion, decoder_input, library, batch_index, *categorical_input, y
    )
    if self.dispersion == "gene-label":
        px_r = F.linear(
            one_hot(y, self.n_labels), self.px_r
        )  # px_r gets transposed - last dimension is nb genes
    elif self.dispersion == "gene-batch":
        px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
    elif self.dispersion == "gene":
        px_r = self.px_r
    px_r = torch.exp(px_r)

    return dict(
        px_scale=px_scale, px_r=px_r, px_rate=px_rate, px_dropout=px_dropout
    )
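# --- Illustrative sketch, not part of the module above. It shows how the
# categorical covariates are unpacked before being handed to the decoder.
# Shapes and values are assumptions chosen for the example.
import torch

B = 4
# two categorical covariates stored column-wise, e.g. donor id and assay
cat_covs = torch.tensor([[0, 1], [2, 0], [1, 1], [0, 0]])

categorical_input = torch.split(cat_covs, 1, dim=1)
# -> a tuple of two (B, 1) tensors, one per covariate, which is then expanded
#    into separate positional arguments via *categorical_input
assert len(categorical_input) == 2
assert all(t.shape == (B, 1) for t in categorical_input)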
def reshape_bernoulli(
    self,
    bernoulli_params: torch.Tensor,
    batch_index: Optional[torch.Tensor] = None,
    y: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if self.zero_inflation == "gene-label":
        one_hot_label = one_hot(y, self.n_labels)
        # If we sampled several random Bernoulli parameters
        if len(bernoulli_params.shape) == 2:
            bernoulli_params = F.linear(one_hot_label, bernoulli_params)
        else:
            bernoulli_params_res = []
            for sample in range(bernoulli_params.shape[0]):
                bernoulli_params_res.append(
                    F.linear(one_hot_label, bernoulli_params[sample])
                )
            bernoulli_params = torch.stack(bernoulli_params_res)
    elif self.zero_inflation == "gene-batch":
        one_hot_batch = one_hot(batch_index, self.n_batch)
        if len(bernoulli_params.shape) == 2:
            bernoulli_params = F.linear(one_hot_batch, bernoulli_params)
        # If we sampled several random Bernoulli parameters
        else:
            bernoulli_params_res = []
            for sample in range(bernoulli_params.shape[0]):
                bernoulli_params_res.append(
                    F.linear(one_hot_batch, bernoulli_params[sample])
                )
            bernoulli_params = torch.stack(bernoulli_params_res)

    return bernoulli_params
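# --- Illustrative sketch, not part of the module above. It traces the two
# code paths of reshape_bernoulli at the shape level, assuming the parameters
# are (n_genes, n_labels) for a point estimate and (n_samples, n_genes,
# n_labels) when several parameters were sampled.
import torch
import torch.nn.functional as F

B, n_genes, n_labels, n_samples = 3, 6, 2, 5
one_hot_label = F.one_hot(torch.tensor([0, 1, 1]), n_labels).float()  # (B, n_labels)

# point estimate: a single F.linear gives one Bernoulli parameter per (cell, gene)
point = torch.rand(n_genes, n_labels)
assert F.linear(one_hot_label, point).shape == (B, n_genes)

# sampled parameters: apply F.linear per sample, then stack along a leading dim
sampled = torch.rand(n_samples, n_genes, n_labels)
stacked = torch.stack([F.linear(one_hot_label, s) for s in sampled])
assert stacked.shape == (n_samples, B, n_genes)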
def generative(
    self, z, library_gene, batch_index, label, cont_covs=None, cat_covs=None
):
    decoder_input = z if cont_covs is None else torch.cat([z, cont_covs], dim=-1)
    if cat_covs is not None:
        categorical_input = torch.split(cat_covs, 1, dim=1)
    else:
        categorical_input = tuple()
    px_, py_, log_pro_back_mean = self.decoder(
        decoder_input, library_gene, batch_index, *categorical_input
    )

    if self.gene_dispersion == "gene-label":
        # px_r gets transposed - last dimension is nb genes
        px_r = F.linear(one_hot(label, self.n_labels), self.px_r)
    elif self.gene_dispersion == "gene-batch":
        px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
    elif self.gene_dispersion == "gene":
        px_r = self.px_r
    px_r = torch.exp(px_r)

    if self.protein_dispersion == "protein-label":
        # py_r gets transposed - last dimension is n_proteins
        py_r = F.linear(one_hot(label, self.n_labels), self.py_r)
    elif self.protein_dispersion == "protein-batch":
        py_r = F.linear(one_hot(batch_index, self.n_batch), self.py_r)
    elif self.protein_dispersion == "protein":
        py_r = self.py_r
    py_r = torch.exp(py_r)

    px_["r"] = px_r
    py_["r"] = py_r
    return dict(
        px_=px_,
        py_=py_,
        log_pro_back_mean=log_pro_back_mean,
    )
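# --- Illustrative sketch, not part of the module above. It mimics the final
# step of the function: the decoder returns per-modality parameter
# dictionaries, and the exponentiated (hence strictly positive) dispersions
# are attached under the "r" key. All names and shapes are toy assumptions.
import torch

px_ = {"rate": torch.rand(4, 10), "scale": torch.rand(4, 10)}
py_ = {"rate_fore": torch.rand(4, 3), "rate_back": torch.rand(4, 3)}

px_["r"] = torch.exp(torch.randn(10))   # per-gene inverse dispersion, > 0
py_["r"] = torch.exp(torch.randn(3))    # per-protein inverse dispersion, > 0
assert (px_["r"] > 0).all() and (py_["r"] > 0).all()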
def loss_adversarial_classifier(self, z, batch_index, predict_true_class=True):
    n_classes = self.n_output_classifier
    cls_logits = torch.nn.LogSoftmax(dim=1)(self.adversarial_classifier(z))
    if predict_true_class:
        cls_target = one_hot(batch_index, n_classes)
    else:
        one_hot_batch = one_hot(batch_index, n_classes)
        cls_target = torch.zeros_like(one_hot_batch)
        # place zeroes where the true label is, uniform mass on the other classes
        cls_target.masked_scatter_(
            ~one_hot_batch.bool(), torch.ones_like(one_hot_batch) / (n_classes - 1)
        )

    l_soft = cls_logits * cls_target
    loss = -l_soft.sum(dim=1).mean()
    return loss
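# --- Illustrative sketch, not part of the module above. It reproduces the
# masked_scatter_ trick used when predict_true_class is False: the target
# distribution puts zero mass on the true batch and uniform mass elsewhere.
# n_classes and batch_index are toy values.
import torch
import torch.nn.functional as F

n_classes = 3
batch_index = torch.tensor([[0], [2]])
one_hot_batch = F.one_hot(batch_index.squeeze(-1), n_classes).float()

cls_target = torch.zeros_like(one_hot_batch)
cls_target.masked_scatter_(
    ~one_hot_batch.bool(), torch.ones_like(one_hot_batch) / (n_classes - 1)
)
# e.g. a cell from batch 0 gets the target [0.0, 0.5, 0.5]
assert torch.allclose(cls_target[0], torch.tensor([0.0, 0.5, 0.5]))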
def broadcast_labels(y, *o, n_broadcast=-1):
    """
    Utility for the semi-supervised setting.

    If y is defined (labelled batch), then one-hot encode the labels (no
    broadcasting needed). If y is undefined (unlabelled batch), then generate
    all possible labels (and broadcast the other arguments if they are not None).
    """
    if not len(o):
        raise ValueError("Broadcast must have at least one reference argument")
    if y is None:
        ys = enumerate_discrete(o[0], n_broadcast)
        new_o = iterate(
            o,
            lambda x: x.repeat(n_broadcast, 1)
            if len(x.size()) == 2
            else x.repeat(n_broadcast),
        )
    else:
        ys = one_hot(y, n_broadcast)
        new_o = o
    return (ys,) + new_o
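# --- Illustrative sketch, not part of the module above. `enumerate_discrete`
# and `iterate` are helpers of the surrounding codebase; the stand-in below
# only approximates the unlabelled branch, where every cell is paired with
# every possible label and the other arguments are tiled to match.
import torch


def enumerate_discrete_sketch(x: torch.Tensor, y_dim: int) -> torch.Tensor:
    # (B, ...) -> (B * y_dim, y_dim): one block of identical one-hot labels per class
    B = x.size(0)
    labels = torch.arange(y_dim).repeat_interleave(B)   # 0..0, 1..1, ..., per block
    return torch.nn.functional.one_hot(labels, y_dim).float()


x = torch.rand(4, 7)
ys = enumerate_discrete_sketch(x, y_dim=3)
x_rep = x.repeat(3, 1)   # matches the broadcasting applied to the other arguments
assert ys.shape == (12, 3) and x_rep.shape == (12, 7)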
def inference(
    self,
    x: torch.Tensor,
    y: torch.Tensor,
    batch_index: Optional[torch.Tensor] = None,
    label: Optional[torch.Tensor] = None,
    n_samples=1,
    transform_batch: Optional[int] = None,
    cont_covs=None,
    cat_covs=None,
) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]:
    """
    Internal helper function to compute necessary inference quantities.

    We use the dictionary ``px_`` to contain the parameters of the ZINB/NB for genes.
    The rate refers to the mean of the NB, dropout refers to the Bernoulli mixing
    parameters. `scale` refers to the quantity upon which differential expression is
    performed. For genes, this can be viewed as the mean of the underlying gamma
    distribution.

    We use the dictionary ``py_`` to contain the parameters of the Mixture NB
    distribution for proteins. `rate_fore` refers to the foreground mean, while
    `rate_back` refers to the background mean. ``scale`` refers to the foreground
    mean adjusted for background probability and scaled to reside in the simplex.
    ``back_alpha`` and ``back_beta`` are the posterior parameters for ``rate_back``.
    ``fore_scale`` is the scaling factor that enforces `rate_fore` > `rate_back`.

    ``px_["r"]`` and ``py_["r"]`` are the inverse dispersion parameters for genes
    and proteins, respectively.

    Parameters
    ----------
    x
        tensor of values with shape ``(batch_size, n_input_genes)``
    y
        tensor of values with shape ``(batch_size, n_input_proteins)``
    batch_index
        array that indicates which batch the cells belong to with shape ``batch_size``
    label
        tensor of cell-type labels with shape ``(batch_size, n_labels)``
    n_samples
        Number of samples to sample from approximate posterior
    transform_batch
        If not None, will override batch_index
    cont_covs
        Continuous covariates to condition on
    cat_covs
        Categorical covariates to condition on
    """
    x_ = x
    y_ = y
    if self.use_observed_lib_size:
        library_gene = x.sum(1).unsqueeze(1)
    if self.log_variational:
        x_ = torch.log(1 + x_)
        y_ = torch.log(1 + y_)

    if cont_covs is not None and self.encode_covariates is True:
        encoder_input = torch.cat((x_, y_, cont_covs), dim=-1)
    else:
        encoder_input = torch.cat((x_, y_), dim=-1)
    if cat_covs is not None and self.encode_covariates is True:
        categorical_input = torch.split(cat_covs, 1, dim=1)
    else:
        categorical_input = tuple()

    qz_m, qz_v, ql_m, ql_v, latent, untran_latent = self.encoder(
        encoder_input, batch_index, *categorical_input
    )
    z = latent["z"]
    untran_z = untran_latent["z"]
    untran_l = untran_latent["l"]
    if not self.use_observed_lib_size:
        library_gene = latent["l"]

    if n_samples > 1:
        qz_m = qz_m.unsqueeze(0).expand((n_samples, qz_m.size(0), qz_m.size(1)))
        qz_v = qz_v.unsqueeze(0).expand((n_samples, qz_v.size(0), qz_v.size(1)))
        untran_z = Normal(qz_m, qz_v.sqrt()).sample()
        z = self.encoder.z_transformation(untran_z)
        ql_m = ql_m.unsqueeze(0).expand((n_samples, ql_m.size(0), ql_m.size(1)))
        ql_v = ql_v.unsqueeze(0).expand((n_samples, ql_v.size(0), ql_v.size(1)))
        untran_l = Normal(ql_m, ql_v.sqrt()).sample()
        if self.use_observed_lib_size:
            library_gene = library_gene.unsqueeze(0).expand(
                (n_samples, library_gene.size(0), library_gene.size(1))
            )
        else:
            library_gene = self.encoder.l_transformation(untran_l)

    # Background regularization
    if self.gene_dispersion == "gene-label":
        # px_r gets transposed - last dimension is nb genes
        px_r = F.linear(one_hot(label, self.n_labels), self.px_r)
    elif self.gene_dispersion == "gene-batch":
        px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
    elif self.gene_dispersion == "gene":
        px_r = self.px_r
    px_r = torch.exp(px_r)
    if self.protein_dispersion == "protein-label":
        # py_r gets transposed - last dimension is n_proteins
        py_r = F.linear(one_hot(label, self.n_labels), self.py_r)
    elif self.protein_dispersion == "protein-batch":
        py_r = F.linear(one_hot(batch_index, self.n_batch), self.py_r)
    elif self.protein_dispersion == "protein":
        py_r = self.py_r
    py_r = torch.exp(py_r)

    if self.n_batch > 0:
        py_back_alpha_prior = F.linear(
            one_hot(batch_index, self.n_batch), self.background_pro_alpha
        )
        py_back_beta_prior = F.linear(
            one_hot(batch_index, self.n_batch),
            torch.exp(self.background_pro_log_beta),
        )
    else:
        py_back_alpha_prior = self.background_pro_alpha
        py_back_beta_prior = torch.exp(self.background_pro_log_beta)
    self.back_mean_prior = Normal(py_back_alpha_prior, py_back_beta_prior)

    if transform_batch is not None:
        batch_index = torch.ones_like(batch_index) * transform_batch

    return dict(
        qz_m=qz_m,
        qz_v=qz_v,
        z=z,
        untran_z=untran_z,
        ql_m=ql_m,
        ql_v=ql_v,
        library_gene=library_gene,
        untran_l=untran_l,
    )
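# --- Illustrative sketch, not part of the module above. It traces, at the
# shape level, the n_samples > 1 branch of `inference`: the posterior mean and
# variance are expanded along a leading sample dimension and fresh latent
# draws are taken. Dimensions are toy values.
import torch
from torch.distributions import Normal

B, n_latent, n_samples = 4, 10, 7
qz_m, qz_v = torch.zeros(B, n_latent), torch.ones(B, n_latent)

qz_m = qz_m.unsqueeze(0).expand((n_samples, B, n_latent))
qz_v = qz_v.unsqueeze(0).expand((n_samples, B, n_latent))
untran_z = Normal(qz_m, qz_v.sqrt()).sample()   # (n_samples, B, n_latent)
assert untran_z.shape == (n_samples, B, n_latent)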
def batch(batch_size, label):
    labels = torch.ones(batch_size, 1, device=x.device, dtype=torch.long) * label
    return one_hot(labels, y_dim)
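# --- Illustrative sketch, not part of the module above. The `batch` helper
# relies on `x` and `y_dim` from its enclosing scope; the stand-in below uses
# toy values to show that it yields a (batch_size, y_dim) one-hot matrix in
# which every row encodes the same label.
import torch
import torch.nn.functional as F

x = torch.rand(5, 3)
y_dim = 4


def batch_sketch(batch_size, label):
    labels = torch.ones(batch_size, 1, device=x.device, dtype=torch.long) * label
    return F.one_hot(labels.squeeze(-1), y_dim).float()


assert torch.equal(
    batch_sketch(2, 3),
    torch.tensor([[0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 1.0]]),
)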