def inference(self, x, batch_index=None, y=None, n_samples=1):
    x_ = x
    if self.log_variational:
        x_ = torch.log(1 + x_)

    # Sampling
    qz_m, qz_v, z = self.z_encoder(x_, y)
    px_r, px_rate = self.decoder(self.dispersion, z, batch_index, y)

    if self.dispersion == "gene-label":
        px_r = F.linear(
            one_hot(y, self.n_labels), self.px_r
        )  # px_r gets transposed - last dimension is nb genes
    elif self.dispersion == "gene-batch":
        px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
    elif self.dispersion == "gene":
        px_r = self.px_r
    px_rate = torch.exp(px_rate)
    px_r = torch.exp(px_r)

    return dict(
        px_r=px_r,
        px_rate=px_rate,
        qz_m=qz_m,
        qz_v=qz_v,
        z=z,
    )
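# A minimal, self-contained sketch (not part of the model code) of the per-label
# dispersion selection used throughout these snippets. `_one_hot_sketch` is a
# hypothetical stand-in for the imported `one_hot` helper, and the sizes
# (4 cells, 3 labels, 5 genes) are made up for illustration. It assumes, as the
# inline comments suggest, that `px_r` is stored as (n_genes, n_labels) so that
# F.linear(one_hot(y, n_labels), px_r) yields one dispersion row per cell.
import torch
import torch.nn.functional as F


def _one_hot_sketch(index: torch.Tensor, n_cat: int) -> torch.Tensor:
    # Turn a (batch, 1) integer index into a (batch, n_cat) float one-hot matrix.
    onehot = torch.zeros(index.size(0), n_cat, device=index.device)
    return onehot.scatter_(1, index.type(torch.long), 1).float()


y_demo = torch.tensor([[0], [2], [1], [2]])       # hypothetical labels
px_r_demo = torch.randn(5, 3)                     # (n_genes, n_labels)
per_cell_r = F.linear(_one_hot_sketch(y_demo, 3), px_r_demo)
assert per_cell_r.shape == (4, 5)                 # one dispersion vector per cell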
def decode(
    self,
    z: torch.Tensor,
    mode: int,
    library: torch.Tensor,
    batch_index: Optional[torch.Tensor] = None,
    y: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    px_scale, px_r, px_rate, px_dropout = self.decoder(
        z, mode, library, self.dispersion, batch_index, y
    )
    if self.dispersion == "gene-label":
        px_r = F.linear(one_hot(y, self.n_labels), self.px_r)
    elif self.dispersion == "gene-batch":
        px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
    elif self.dispersion == "gene":
        px_r = self.px_r.view(1, self.px_r.size(0))
    px_r = torch.exp(px_r)

    px_scale = px_scale / torch.sum(
        px_scale[:, self.indices_mappings[mode]], dim=1
    ).view(-1, 1)
    px_rate = px_scale * torch.exp(library)

    return px_scale, px_r, px_rate, px_dropout
def _reconstruction_loss(
    self, x, px_rate, px_r, px_dropout, batch_index, y, mode="scRNA", weighting=1
):
    if self.dispersion == "gene-label":
        px_r = F.linear(
            one_hot(y, self.n_labels), self.px_r
        )  # px_r gets transposed - last dimension is nb genes
    elif self.dispersion == "gene-batch":
        px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
    elif self.dispersion == "gene":
        px_r = self.px_r

    # Reconstruction Loss
    if mode == "scRNA":
        if self.reconstruction_loss == 'zinb':
            reconst_loss = -log_zinb_positive(x, px_rate, torch.exp(px_r), px_dropout)
        elif self.reconstruction_loss == 'nb':
            reconst_loss = -log_nb_positive(x, px_rate, torch.exp(px_r))
    else:
        if self.reconstruction_loss_fish == 'poisson':
            reconst_loss = -torch.sum(Poisson(px_rate).log_prob(x), dim=1)
        elif self.reconstruction_loss_fish == 'gaussian':
            reconst_loss = -torch.sum(Normal(px_rate, 10).log_prob(x), dim=1)
    return reconst_loss
def reshape_bernoulli(
    self,
    bernoulli_params: torch.Tensor,
    batch_index: Optional[torch.Tensor] = None,
    y: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if self.zero_inflation == "gene-label":
        one_hot_label = one_hot(y, self.n_labels)
        if len(bernoulli_params.shape) == 2:
            bernoulli_params = F.linear(one_hot_label, bernoulli_params)
        # If we sampled several random Bernoulli parameters
        else:
            bernoulli_params_res = []
            for sample in range(bernoulli_params.shape[0]):
                bernoulli_params_res.append(
                    F.linear(one_hot_label, bernoulli_params[sample])
                )
            bernoulli_params = torch.stack(bernoulli_params_res)
    elif self.zero_inflation == "gene-batch":
        one_hot_batch = one_hot(batch_index, self.n_batch)
        if len(bernoulli_params.shape) == 2:
            bernoulli_params = F.linear(one_hot_batch, bernoulli_params)
        # If we sampled several random Bernoulli parameters
        else:
            bernoulli_params_res = []
            for sample in range(bernoulli_params.shape[0]):
                bernoulli_params_res.append(
                    F.linear(one_hot_batch, bernoulli_params[sample])
                )
            bernoulli_params = torch.stack(bernoulli_params_res)

    return bernoulli_params
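# Shape check (illustrative only, hypothetical sizes) for the sampled-parameter
# branch of `reshape_bernoulli` above: when several Bernoulli parameter draws are
# stacked along a leading sample dimension, the same per-label selection is applied
# slice by slice, giving a (n_samples, batch, n_genes) result.
import torch
import torch.nn.functional as F

n_samples, n_genes, n_labels, batch = 2, 5, 3, 4
one_hot_label_demo = torch.eye(n_labels)[torch.randint(0, n_labels, (batch,))]  # (batch, n_labels)
bernoulli_draws = torch.rand(n_samples, n_genes, n_labels)                      # sampled parameters
stacked = torch.stack(
    [F.linear(one_hot_label_demo, bernoulli_draws[s]) for s in range(n_samples)]
)
assert stacked.shape == (n_samples, batch, n_genes)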
def inference(self, x, batch_index=None, y=None, n_samples=1, force_batch=None):
    x_ = x
    if self.log_variational:
        x_ = torch.log(1 + x_)
    if force_batch is not None:
        batch_index = torch.zeros_like(batch_index).fill_(force_batch)

    # Sampling
    qz_m, qz_v, z = self.z_encoder(x_, y)
    ql_m, ql_v, library = self.l_encoder(x_)

    if n_samples > 1:
        qz_m = qz_m.unsqueeze(0).expand((n_samples, qz_m.size(0), qz_m.size(1)))
        qz_v = qz_v.unsqueeze(0).expand((n_samples, qz_v.size(0), qz_v.size(1)))
        z = Normal(qz_m, qz_v.sqrt()).sample()
        ql_m = ql_m.unsqueeze(0).expand((n_samples, ql_m.size(0), ql_m.size(1)))
        ql_v = ql_v.unsqueeze(0).expand((n_samples, ql_v.size(0), ql_v.size(1)))
        library = Normal(ql_m, ql_v.sqrt()).sample()

    px_scale, px_r, px_rate, px_dropout = self.decoder(
        self.dispersion, z, library, batch_index, y
    )
    if self.dispersion == "gene-label":
        px_r = F.linear(
            one_hot(y, self.n_labels), self.px_r
        )  # px_r gets transposed - last dimension is nb genes
    elif self.dispersion == "gene-batch":
        px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
    elif self.dispersion == "gene":
        px_r = self.px_r
    px_r = torch.exp(px_r)

    return px_scale, px_r, px_rate, px_dropout, qz_m, qz_v, z, ql_m, ql_v, library
def inference(
    self, x, batch_index=None, y=None, n_samples=1, transform_batch=None
) -> Dict[str, torch.Tensor]:
    """Helper function used in forward pass"""
    x_ = x
    if self.log_variational:
        x_ = torch.log(1 + x_)

    # Sampling
    qz_m, qz_v, z = self.z_encoder(x_, y)
    ql_m, ql_v, library = self.l_encoder(x_)

    if n_samples > 1:
        qz_m = qz_m.unsqueeze(0).expand((n_samples, qz_m.size(0), qz_m.size(1)))
        qz_v = qz_v.unsqueeze(0).expand((n_samples, qz_v.size(0), qz_v.size(1)))
        # when z is normal, untran_z == z
        untran_z = Normal(qz_m, qz_v.sqrt()).sample()
        z = self.z_encoder.z_transformation(untran_z)
        ql_m = ql_m.unsqueeze(0).expand((n_samples, ql_m.size(0), ql_m.size(1)))
        ql_v = ql_v.unsqueeze(0).expand((n_samples, ql_v.size(0), ql_v.size(1)))
        library = Normal(ql_m, ql_v.sqrt()).sample()

    if transform_batch is not None:
        dec_batch_index = transform_batch * torch.ones_like(batch_index)
    else:
        dec_batch_index = batch_index

    px_scale, px_r, px_rate, px_dropout = self.decoder(
        self.dispersion, z, library, dec_batch_index, y
    )
    if self.dispersion == "gene-label":
        px_r = F.linear(
            one_hot(y, self.n_labels), self.px_r
        )  # px_r gets transposed - last dimension is nb genes
    elif self.dispersion == "gene-batch":
        px_r = F.linear(one_hot(dec_batch_index, self.n_batch), self.px_r)
    elif self.dispersion == "gene":
        px_r = self.px_r
    px_r = torch.exp(px_r)

    return dict(
        px_scale=px_scale,
        px_r=px_r,
        px_rate=px_rate,
        px_dropout=px_dropout,
        qz_m=qz_m,
        qz_v=qz_v,
        z=z,
        ql_m=ql_m,
        ql_v=ql_v,
        library=library,
    )
def inference(self, x, batch_index=None, y=None, mode="scRNA", weighting=1):
    x_ = x
    if self.log_variational:
        x_ = torch.log(1 + x_)

    # Sampling
    if mode == "scRNA":
        qz_m, qz_v, z = self.z_encoder(x_)
        library = torch.log(torch.sum(x, dim=1)).view(-1, 1)
        batch_index = torch.zeros_like(library)
    if mode == "smFISH":
        qz_m, qz_v, z = self.z_encoder_fish(x_[:, self.indexes_to_keep])
        library = torch.log(torch.sum(x[:, self.indexes_to_keep], dim=1)).view(-1, 1)
        batch_index = torch.ones_like(library)

    if self.model_library:
        if mode == "scRNA":
            ql_m, ql_v, library = self.l_encoder(x_)
        elif mode == "smFISH":
            ql_m, ql_v, library = self.l_encoder_fish(x_[:, self.indexes_to_keep])
    else:
        ql_m, ql_v = None, None

    qz_m, qz_v, z = self.z_final_encoder(z)
    px_scale, px_r, px_rate, px_dropout = self.decoder(
        self.dispersion, z, library, batch_index
    )

    # rescaling the expected frequencies
    if mode == "smFISH":
        if self.model_library:
            px_rate = px_scale[:, self.indexes_to_keep] * torch.exp(library)
        else:
            px_scale = px_scale[:, self.indexes_to_keep] / torch.sum(
                px_scale[:, self.indexes_to_keep], dim=1
            ).view(-1, 1)
            px_rate = px_scale * torch.exp(library)

    if self.dispersion == "gene-label":
        px_r = F.linear(
            one_hot(y, self.n_labels), self.px_r
        )  # px_r gets transposed - last dimension is nb genes
    elif self.dispersion == "gene-batch":
        px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
    elif self.dispersion == "gene":
        px_r = self.px_r

    return px_scale, px_r, px_rate, px_dropout, qz_m, qz_v, z, ql_m, ql_v, library
def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None):  # same signature as loss
    # Parameters for z latent distribution
    x_ = x
    if self.log_variational:
        x_ = torch.log(1 + x_)

    # Sampling
    qz_m, qz_v, z = self.z_encoder(x_)
    ql_m, ql_v, library = self.l_encoder(x_)

    if self.dispersion == "gene-cell":
        px_scale, self.px_r, px_rate, px_dropout = self.decoder(
            self.dispersion, z, library, batch_index
        )
    else:  # self.dispersion == "gene", "gene-batch", "gene-label"
        px_scale, px_rate, px_dropout = self.decoder(
            self.dispersion, z, library, batch_index
        )

    if self.dispersion == "gene-label":
        px_r = F.linear(
            one_hot(y, self.n_labels), self.px_r
        )  # px_r gets transposed - last dimension is nb genes
    elif self.dispersion == "gene-batch":
        px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
    else:
        px_r = self.px_r

    # Reconstruction Loss
    if self.reconstruction_loss == 'zinb':
        reconst_loss = -log_zinb_positive(x, px_rate, torch.exp(px_r), px_dropout)
    elif self.reconstruction_loss == 'nb':
        reconst_loss = -log_nb_positive(x, px_rate, torch.exp(px_r))

    # KL Divergence
    mean = torch.zeros_like(qz_m)
    scale = torch.ones_like(qz_v)
    kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(dim=1)
    kl_divergence_l = kl(
        Normal(ql_m, torch.sqrt(ql_v)),
        Normal(local_l_mean, torch.sqrt(local_l_var)),
    ).sum(dim=1)
    kl_divergence = kl_divergence_z + kl_divergence_l

    return reconst_loss, kl_divergence
def _reconstruction_loss(self, x, px_rate, px_r, px_dropout, batch_index, y):
    if self.dispersion == "gene-label":
        px_r = F.linear(
            one_hot(y, self.n_labels), self.px_r
        )  # px_r gets transposed - last dimension is nb genes
    elif self.dispersion == "gene-batch":
        px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
    elif self.dispersion == "gene":
        px_r = self.px_r

    # Reconstruction Loss
    if self.reconstruction_loss == 'zinb':
        reconst_loss = -log_zinb_positive(x, px_rate, torch.exp(px_r), px_dropout)
    elif self.reconstruction_loss == 'nb':
        reconst_loss = -log_nb_positive(x, px_rate, torch.exp(px_r))
    return reconst_loss
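# For reference only: a sketch of the negative binomial log-likelihood that
# `log_nb_positive` is assumed to compute, in the (mean mu, inverse dispersion
# theta) parameterization used above; the zero-inflated variant behind
# `log_zinb_positive` mixes this density with a Bernoulli dropout gate
# parameterized by `px_dropout`. Not the library implementation.
import torch


def log_nb_positive_sketch(x, mu, theta, eps=1e-8):
    # log NB(x | mu, theta), summed over genes.
    log_theta_mu = torch.log(theta + mu + eps)
    res = (
        theta * (torch.log(theta + eps) - log_theta_mu)
        + x * (torch.log(mu + eps) - log_theta_mu)
        + torch.lgamma(x + theta)
        - torch.lgamma(theta)
        - torch.lgamma(x + 1)
    )
    return torch.sum(res, dim=-1)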
def forward(self, x: torch.Tensor, *cat_list: int):
    r"""Forward computation on ``x``.

    :param x: tensor of values with shape ``(n_in,)``
    :param cat_list: list of category membership(s) for this sample
    :return: tensor of shape ``(n_out,)``
    :rtype: :py:class:`torch.Tensor`
    """
    one_hot_cat_list = []  # for generality, many entries in this list may go unused
    assert len(self.n_cat_list) <= len(
        cat_list
    ), "nb. categorical args provided doesn't match init. params."
    for n_cat, cat in zip(self.n_cat_list, cat_list):
        assert not (
            n_cat and cat is None
        ), "cat not provided while n_cat != 0 in init. params."
        if n_cat > 1:  # n_cat = 1 will be ignored - no additional information
            if cat.size(1) != n_cat:
                one_hot_cat = one_hot(cat, n_cat)
            else:
                one_hot_cat = cat  # cat has already been one_hot encoded
            one_hot_cat_list += [one_hot_cat]
    for layers in self.fc_layers:
        for layer in layers:
            if layer is not None:
                if isinstance(layer, nn.BatchNorm1d):
                    if x.dim() == 3:
                        x = torch.cat(
                            [(layer(slice_x)).unsqueeze(0) for slice_x in x], dim=0
                        )
                    else:
                        x = layer(x)
                else:
                    if isinstance(layer, nn.Linear):
                        if x.dim() == 3:
                            one_hot_cat_list = [
                                o.unsqueeze(0).expand((x.size(0), o.size(0), o.size(1)))
                                for o in one_hot_cat_list
                            ]
                        x = torch.cat((x, *one_hot_cat_list), dim=-1)
                    x = layer(x)
    return x
def inference(self, x, batch_index=None, y=None, n_samples=1, train_library=True):
    x_ = x
    if self.log_variational:
        x_ = torch.log(1 + x_)

    # Sampling
    qz_m, qz_v, z = self.z_encoder(x_, y)
    ql_m, ql_v, library = self.l_encoder(x_)

    if n_samples > 1:
        assert not self.z_full_cov  # TODO: Check no issues when full cov
        qz_m = qz_m.unsqueeze(0).expand([n_samples] + list(qz_m.size()))
        qz_v = qz_v.unsqueeze(0).expand([n_samples] + list(qz_v.size()))
        ql_m = ql_m.unsqueeze(0).expand([n_samples] + list(ql_m.size()))
        ql_v = ql_v.unsqueeze(0).expand([n_samples] + list(ql_v.size()))
        z = self.z_encoder.sample(qz_m, qz_v)
        library = self.l_encoder.sample(ql_m, ql_v)

    if not train_library:
        library = 1.0

    px_scale, px_r, px_rate, px_dropout = self.decoder(
        self.dispersion, z, library, batch_index, y
    )
    if self.dispersion == "gene-label":
        px_r = F.linear(
            one_hot(y, self.n_labels), self.px_r
        )  # px_r gets transposed - last dimension is nb genes
    elif self.dispersion == "gene-batch":
        px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
    elif self.dispersion == "gene":
        px_r = self.px_r
    px_r = torch.exp(px_r)

    return dict(
        px_scale=px_scale,
        px_r=px_r,
        px_rate=px_rate,
        px_dropout=px_dropout,
        qz_m=qz_m,
        qz_v=qz_v,
        z=z,
        ql_m=ql_m,
        ql_v=ql_v,
        library=library,
    )
def impute_from_z(self, fixed_batch_indices, fixed_l, sample=False):
    for tensors in self:
        sample_batch, local_l_mean, local_l_var, batch_index, label = tensors
        if not sample:
            if self.model.log_variational:
                sample_batch = torch.log(1 + sample_batch)
            z = self.model.z_encoder(sample_batch)[0]
        else:
            z = self.model.sample_from_posterior_z(sample_batch)

        px_scale, px_r, px_rate, px_dropout = self.model.decoder(
            self.model.dispersion, z, fixed_l, fixed_batch_indices
        )
        if self.model.dispersion == "gene-label":
            px_r = F.linear(
                one_hot(label, self.model.n_labels), self.model.px_r
            )  # px_r gets transposed - last dimension is nb genes
        elif self.model.dispersion == "gene-batch":
            px_r = F.linear(one_hot(batch_index, self.model.n_batch), self.model.px_r)
        elif self.model.dispersion == "gene":
            px_r = self.model.px_r
        px_r = torch.exp(px_r)
        return px_r
def forward(self, x, *os):
    one_hot_os = []
    for i, o in enumerate(os):
        if o is not None and self.n_cat_list[i]:
            one_hot_o = o
            if o.size(1) != self.n_cat_list[i]:
                one_hot_o = one_hot(o, self.n_cat_list[i])
            elif o.size(1) == 1 and self.n_cat_list[i] == 1:
                one_hot_o = o.type(torch.float32)
            one_hot_os += [one_hot_o]
    for layer in self.fc_layers:
        x = layer(torch.cat((x,) + tuple(one_hot_os), 1))
    return x
def loss_discriminator(
    self, z, batch_index, predict_true_class=True, return_details=True
):
    n_classes = self.gene_dataset.n_batches
    cls_logits = torch.nn.LogSoftmax(dim=1)(self.discriminator(z))

    if predict_true_class:
        cls_target = one_hot(batch_index, n_classes)
    else:
        one_hot_batch = one_hot(batch_index, n_classes)
        cls_target = torch.zeros_like(one_hot_batch)
        # place zeroes where true label is
        cls_target.masked_scatter_(
            ~one_hot_batch.bool(), torch.ones_like(one_hot_batch) / (n_classes - 1)
        )

    l_soft = cls_logits * cls_target
    loss = -l_soft.sum(dim=1).mean()
    return loss
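# Small illustration (hypothetical values) of the soft target built in the
# `predict_true_class=False` branch above: mass 1 / (n_classes - 1) is spread
# over every class except the true batch, which keeps a zero.
import torch

n_classes_demo = 3
one_hot_batch_demo = torch.tensor([[0.0, 1.0, 0.0]])
cls_target_demo = torch.zeros_like(one_hot_batch_demo)
cls_target_demo.masked_scatter_(
    ~one_hot_batch_demo.bool(),
    torch.ones_like(one_hot_batch_demo) / (n_classes_demo - 1),
)
assert torch.allclose(cls_target_demo, torch.tensor([[0.5, 0.0, 0.5]]))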
def forward(
    self, dispersion: str, z: torch.Tensor, library: torch.Tensor, *cat_list: int
):
    # The decoder returns values for the parameters of the ZINB distribution
    p1_ = self.factor_regressor(z)
    if self.n_batches > 1:
        one_hot_cat = one_hot(cat_list[0], self.n_batches)[:, :-1]
        p2_ = self.batch_regressor(one_hot_cat)
        raw_px_scale = p1_ + p2_
    else:
        raw_px_scale = p1_

    px_scale = torch.softmax(raw_px_scale, dim=-1)
    px_dropout = self.px_dropout_decoder(z)
    px_rate = torch.exp(library) * px_scale
    px_r = None
    return px_scale, px_r, px_rate, px_dropout
def forward(self, x, *cat_list):
    one_hot_cat_list = []  # for generality, many entries in this list may go unused
    assert len(self.n_cat_list) <= len(
        cat_list
    ), "nb. categorical args provided doesn't match init. params."
    for n_cat, cat in zip(self.n_cat_list, cat_list):
        assert not (
            n_cat and cat is None
        ), "cat not provided while n_cat != 0 in init. params."
        if n_cat > 1:  # n_cat = 1 will be ignored - no additional information
            if cat.size(1) != n_cat:
                one_hot_cat = one_hot(cat, n_cat)
            else:
                one_hot_cat = cat  # cat has already been one_hot encoded
            one_hot_cat_list += [one_hot_cat]
    for layers in self.fc_layers:
        for layer in layers:
            if isinstance(layer, nn.Linear):
                x = torch.cat((x, *one_hot_cat_list), 1)
            x = layer(x)
    return x
def forward(self, x, *cat_list):
    one_hot_cat_list = []  # for generality, many entries in this list may go unused
    assert len(self.n_cat_list) <= len(
        cat_list
    ), "nb. categorical args provided doesn't match init. params."
    for n_cat, cat in zip(self.n_cat_list, cat_list):
        assert not (
            n_cat and cat is None
        ), "cat not provided while n_cat != 0 in init. params."
        if n_cat > 1:  # n_cat = 1 will be ignored - no additional information
            if cat.size(1) != n_cat:
                one_hot_cat = one_hot(cat, n_cat)
            else:
                one_hot_cat = cat  # cat has already been one_hot encoded
            one_hot_cat_list += [one_hot_cat]
    for layers in self.fc_layers:
        for layer in layers:
            if isinstance(layer, nn.BatchNorm1d) and x.dim() == 3:
                x = torch.cat([(layer(slice_x)).unsqueeze(0) for slice_x in x], dim=0)
            else:
                if isinstance(layer, nn.Linear):
                    if x.dim() == 3:
                        one_hot_cat_list = [
                            o.unsqueeze(0).expand((x.size(0), o.size(0), o.size(1)))
                            for o in one_hot_cat_list
                        ]
                    x = torch.cat((x, *one_hot_cat_list), dim=-1)
                x = layer(x)
    return x
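# Shape sketch (hypothetical sizes) for the 3-D path in the two FCLayers-style
# forwards above: when `x` carries a leading Monte Carlo sample dimension, each
# one-hot covariate is expanded to (n_samples, batch, n_cat) before being
# concatenated onto the input of every nn.Linear layer.
import torch

n_samples, batch, n_in, n_cat = 2, 4, 8, 3
x_demo = torch.randn(n_samples, batch, n_in)
o_demo = torch.eye(n_cat)[torch.randint(0, n_cat, (batch,))]    # (batch, n_cat) one-hot
o_3d = o_demo.unsqueeze(0).expand((x_demo.size(0), o_demo.size(0), o_demo.size(1)))
x_cat = torch.cat((x_demo, o_3d), dim=-1)
assert x_cat.shape == (n_samples, batch, n_in + n_cat)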
def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None):
    is_labelled = False if y is None else True

    # Prepare for sampling
    xs, ys = (x, y)

    # Enumerate choices of label
    if not is_labelled:
        ys = enumerate_discrete(xs, self.n_labels)
        xs = xs.repeat(self.n_labels, 1)
        if batch_index is not None:
            batch_index = batch_index.repeat(self.n_labels, 1)
        local_l_var = local_l_var.repeat(self.n_labels, 1)
        local_l_mean = local_l_mean.repeat(self.n_labels, 1)
    else:
        ys = one_hot(ys, self.n_labels)

    xs_ = xs
    if self.log_variational:
        xs_ = torch.log(1 + xs_)

    # Sampling
    qz_m, qz_v, z = self.z_encoder(xs_, ys)
    ql_m, ql_v, library = self.l_encoder(xs_)

    if self.dispersion == "gene-cell":
        px_scale, self.px_r, px_rate, px_dropout = self.decoder(
            self.dispersion, z, library, batch_index, y=ys
        )
    elif self.dispersion == "gene":
        px_scale, px_rate, px_dropout = self.decoder(
            self.dispersion, z, library, batch_index, y=ys
        )

    # Reconstruction Loss
    if self.reconstruction_loss == 'zinb':
        reconst_loss = -log_zinb_positive(xs, px_rate, torch.exp(self.px_r), px_dropout)
    elif self.reconstruction_loss == 'nb':
        reconst_loss = -log_nb_positive(xs, px_rate, torch.exp(self.px_r))

    # KL Divergence
    mean = torch.zeros_like(qz_m)
    scale = torch.ones_like(qz_v)
    kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(dim=1)
    kl_divergence_l = kl(
        Normal(ql_m, torch.sqrt(ql_v)),
        Normal(local_l_mean, torch.sqrt(local_l_var)),
    ).sum(dim=1)
    kl_divergence = kl_divergence_z + kl_divergence_l

    if is_labelled:
        return reconst_loss, kl_divergence

    reconst_loss = reconst_loss.view(self.n_labels, -1)
    kl_divergence = kl_divergence.view(self.n_labels, -1)

    x_ = x
    if self.log_variational:
        x_ = torch.log(1 + x)
    probs = self.classifier(x_)
    reconst_loss = (reconst_loss.t() * probs).sum(dim=1)
    kl_divergence = (kl_divergence.t() * probs).sum(dim=1)
    kl_divergence += kl(Multinomial(probs=probs), Multinomial(probs=self.y_prior))

    return reconst_loss, kl_divergence
def inference(
    self,
    x: torch.Tensor,
    y: torch.Tensor,
    batch_index: Optional[torch.Tensor] = None,
    label: Optional[torch.Tensor] = None,
    n_samples=1,
    transform_batch: Optional[int] = None,
) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]:
    """Internal helper function to compute necessary inference quantities.

    We use the dictionary ``px_`` to contain the parameters of the ZINB/NB for genes.
    The rate refers to the mean of the NB, dropout refers to Bernoulli mixing parameters.
    `scale` refers to the quantity upon which differential expression is performed. For genes,
    this can be viewed as the mean of the underlying gamma distribution.

    We use the dictionary ``py_`` to contain the parameters of the Mixture NB distribution for proteins.
    `rate_fore` refers to the foreground mean, while `rate_back` refers to the background mean.
    ``scale`` refers to the foreground mean adjusted for background probability and scaled to
    reside in the simplex. ``back_alpha`` and ``back_beta`` are the posterior parameters for
    ``rate_back``. ``fore_scale`` is the scaling factor that enforces `rate_fore` > `rate_back`.

    ``px_["r"]`` and ``py_["r"]`` are the inverse dispersion parameters for genes and protein, respectively.
    """
    x_ = x
    y_ = y
    if self.log_variational:
        x_ = torch.log(1 + x_)
        y_ = torch.log(1 + y_)

    # Sampling - Encoder gets concatenated genes + proteins
    qz_m, qz_v, ql_m, ql_v, latent, untran_latent = self.encoder(
        torch.cat((x_, y_), dim=-1), batch_index
    )
    z = latent["z"]
    library_gene = latent["l"]
    untran_z = untran_latent["z"]
    untran_l = untran_latent["l"]

    if n_samples > 1:
        qz_m = qz_m.unsqueeze(0).expand((n_samples, qz_m.size(0), qz_m.size(1)))
        qz_v = qz_v.unsqueeze(0).expand((n_samples, qz_v.size(0), qz_v.size(1)))
        untran_z = Normal(qz_m, qz_v.sqrt()).sample()
        z = self.encoder.z_transformation(untran_z)
        ql_m = ql_m.unsqueeze(0).expand((n_samples, ql_m.size(0), ql_m.size(1)))
        ql_v = ql_v.unsqueeze(0).expand((n_samples, ql_v.size(0), ql_v.size(1)))
        untran_l = Normal(ql_m, ql_v.sqrt()).sample()
        library_gene = self.encoder.l_transformation(untran_l)

    if self.gene_dispersion == "gene-label":
        # px_r gets transposed - last dimension is nb genes
        px_r = F.linear(one_hot(label, self.n_labels), self.px_r)
    elif self.gene_dispersion == "gene-batch":
        px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
    elif self.gene_dispersion == "gene":
        px_r = self.px_r
    px_r = torch.exp(px_r)

    if self.protein_dispersion == "protein-label":
        # py_r gets transposed - last dimension is n_proteins
        py_r = F.linear(one_hot(label, self.n_labels), self.py_r)
    elif self.protein_dispersion == "protein-batch":
        py_r = F.linear(one_hot(batch_index, self.n_batch), self.py_r)
    elif self.protein_dispersion == "protein":
        py_r = self.py_r
    py_r = torch.exp(py_r)

    # Background regularization
    if self.n_batch > 0:
        py_back_alpha_prior = F.linear(
            one_hot(batch_index, self.n_batch), self.background_pro_alpha
        )
        py_back_beta_prior = F.linear(
            one_hot(batch_index, self.n_batch),
            torch.exp(self.background_pro_log_beta),
        )
    else:
        py_back_alpha_prior = self.background_pro_alpha
        py_back_beta_prior = torch.exp(self.background_pro_log_beta)
    self.back_mean_prior = Normal(py_back_alpha_prior, py_back_beta_prior)

    if transform_batch is not None:
        batch_index = torch.ones_like(batch_index) * transform_batch

    px_, py_, log_pro_back_mean = self.decoder(z, library_gene, batch_index, label)
    px_["r"] = px_r
    py_["r"] = py_r

    return dict(
        px_=px_,
        py_=py_,
        qz_m=qz_m,
        qz_v=qz_v,
        z=z,
        untran_z=untran_z,
        ql_m=ql_m,
        ql_v=ql_v,
        library_gene=library_gene,
        untran_l=untran_l,
        log_pro_back_mean=log_pro_back_mean,
    )
def forward(self, x, o, *os):
    if o.size(1) != self.n_cat:
        o = one_hot(o, self.n_cat)
    for layer in self.fc_layers:
        x = layer(torch.cat((x, o), 1))
    return x
def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None):
    is_labelled = False if y is None else True

    xs, ys = (x, y)
    xs_ = torch.log(1 + xs)
    qz1_m, qz1_v, z1_ = self.z_encoder(xs_)
    z1 = z1_

    # Enumerate choices of label
    if not is_labelled:
        ys = enumerate_discrete(xs, self.n_labels)
        xs = xs.repeat(self.n_labels, 1)
        if batch_index is not None:
            batch_index = batch_index.repeat(self.n_labels, 1)
        local_l_var = local_l_var.repeat(self.n_labels, 1)
        local_l_mean = local_l_mean.repeat(self.n_labels, 1)
        qz1_m = qz1_m.repeat(self.n_labels, 1)
        qz1_v = qz1_v.repeat(self.n_labels, 1)
        z1 = z1.repeat(self.n_labels, 1)
    else:
        ys = one_hot(ys, self.n_labels)

    xs_ = torch.log(1 + xs)
    qz2_m, qz2_v, z2 = self.encoder_z2_z1(z1, ys)
    pz1_m, pz1_v = self.decoder_z1_z2(z2, ys)

    # Sampling
    ql_m, ql_v, library = self.l_encoder(xs_)  # let's keep that ind. of y

    px_scale, px_rate, px_dropout = self.decoder(self.dispersion, z1, library, batch_index)
    reconst_loss = -log_zinb_positive(xs, px_rate, torch.exp(self.px_r), px_dropout)

    # KL Divergence
    mean = torch.zeros_like(qz2_m)
    scale = torch.ones_like(qz2_v)
    kl_divergence_z2 = kl(Normal(qz2_m, torch.sqrt(qz2_v)), Normal(mean, scale)).sum(dim=1)
    loss_z1 = (
        -Normal(pz1_m, torch.sqrt(pz1_v)).log_prob(z1)
        + Normal(qz1_m, torch.sqrt(qz1_v)).log_prob(z1)
    ).sum(dim=1)
    kl_divergence_l = kl(
        Normal(ql_m, torch.sqrt(ql_v)),
        Normal(local_l_mean, torch.sqrt(local_l_var)),
    ).sum(dim=1)
    kl_divergence = kl_divergence_z2 + loss_z1 + kl_divergence_l

    if is_labelled:
        return reconst_loss, kl_divergence

    reconst_loss = reconst_loss.view(self.n_labels, -1)
    kl_divergence = kl_divergence.view(self.n_labels, -1)

    probs = self.classifier(z1_)
    reconst_loss = (reconst_loss.t() * probs).sum(dim=1)
    kl_divergence = (kl_divergence.t() * probs).sum(dim=1)
    kl_divergence += kl(Multinomial(probs=probs), Multinomial(probs=self.y_prior))

    return reconst_loss, kl_divergence
def inference(self, x, batch_index=None, y=None, n_samples=1):
    x_ = x
    if len(x_) != 2:
        raise ValueError(
            "Input training data should be 2 data types (RNA and ATAC), "
            "but input was only {}".format(len(x_))
        )
    x_rna = x_[0]
    x_atac = x_[1]
    if self.log_variational:
        x_rna = torch.log(1 + x_rna)
        x_atac = torch.log(1 + x_atac)

    # Sampling
    qz_rna_m, qz_rna_v, rna_z = self.RNA_encoder(x_rna, y)
    qz_atac_m, qz_atac_v, atac_z = self.ATAC_encoder(x_atac, y)
    qz_m, qz_v, z = self.RNA_ATAC_encoder([x_rna, x_atac], y)

    gamma, mu_c, var_c, pi = self.get_gamma(z)  # , self.n_centroids, c_params)
    index = torch.argmax(gamma, dim=1)
    index1 = [i for i in range(len(index))]
    mu_c_max = mu_c[index1, :, index]
    var_c_max = var_c[index1, :, index]
    z_c_max = reparameterize_gaussian(mu_c_max, var_c_max)

    # decoder
    (
        p_rna_scale,
        p_rna_r,
        p_rna_rate,
        p_rna_dropout,
        p_atac_scale,
        p_atac_r,
        p_atac_mean,
        p_atac_dropout,
    ) = self.RNA_ATAC_decoder(z, z_c_max, y)

    if self.dispersion == "gene-label":
        p_rna_r = F.linear(
            one_hot(y, self.n_labels), self.px_r
        )  # px_r gets transposed - last dimension is nb genes
        p_atac_r = F.linear(one_hot(y, self.n_labels), self.p_atac_r)
    elif self.dispersion == "gene-batch":
        p_rna_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
        p_atac_r = F.linear(one_hot(batch_index, self.n_batch), self.p_atac_r)
    elif self.dispersion == "gene":
        p_rna_r = self.px_r
        p_atac_r = self.p_atac_r
    p_rna_r = torch.exp(p_rna_r)
    p_atac_r = torch.exp(p_atac_r)

    return dict(
        p_rna_scale=p_rna_scale,
        p_rna_r=p_rna_r,
        p_rna_rate=p_rna_rate,
        p_rna_dropout=p_rna_dropout,
        p_atac_scale=p_atac_scale,
        p_atac_r=p_atac_r,
        p_atac_mean=p_atac_mean,
        p_atac_dropout=p_atac_dropout,
        qz_rna_m=qz_rna_m,
        qz_rna_v=qz_rna_v,
        rna_z=rna_z,
        qz_atac_m=qz_atac_m,
        qz_atac_v=qz_atac_v,
        atac_z=atac_z,
        qz_m=qz_m,
        qz_v=qz_v,
        z=z,
        mu_c=mu_c,
        var_c=var_c,
        gamma=gamma,
        pi=pi,
        mu_c_max=mu_c_max,
        var_c_max=var_c_max,
        z_c_max=z_c_max,
    )