def loss(
    self,
    tensors,
    inference_outputs,
    generative_outputs,
    kl_weight: float = 1.0,
    n_obs: float = 1.0,
):
    x = tensors[_CONSTANTS.X_KEY]
    px_rate = generative_outputs["px_rate"]
    px_o = generative_outputs["px_o"]

    reconst_loss = -NegativeBinomial(px_rate, logits=px_o).log_prob(x).sum(-1)

    # prior likelihood
    mean = torch.zeros_like(self.eta)
    scale = torch.ones_like(self.eta)
    neg_log_likelihood_prior = -Normal(mean, scale).log_prob(self.eta).sum()

    if self.prior_weight == "n_obs":
        # the correct way to reweight observations while performing stochastic optimization
        loss = n_obs * torch.mean(reconst_loss) + neg_log_likelihood_prior
    else:
        # the original way it is done in Stereoscope;
        # we use this option to show reproducibility of their codebase
        loss = torch.sum(reconst_loss) + neg_log_likelihood_prior

    return LossRecorder(
        loss, reconst_loss, torch.zeros((1,)), neg_log_likelihood_prior
    )
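For reference, the sketch below reproduces the reconstruction term above on toy tensors and shows how the `n_obs` branch rescales a minibatch mean to the full-dataset scale. Shapes, values, and the use of `torch.distributions.NegativeBinomial` in place of the distribution class imported by the module are illustrative assumptions only.

import torch
from torch.distributions import NegativeBinomial

# toy minibatch: 8 spots x 100 genes (all shapes and values are made up)
x = torch.randint(0, 50, (8, 100)).float()
px_rate = torch.rand(8, 100) * 10 + 1e-4   # stand-in for generative_outputs["px_rate"]
px_o = torch.zeros(8, 100)                 # stand-in for generative_outputs["px_o"]

# per-observation negative log-likelihood, as in the loss above
reconst_loss = -NegativeBinomial(px_rate, logits=px_o).log_prob(x).sum(-1)

# with prior_weight == "n_obs", the minibatch mean is rescaled by the dataset
# size so the stochastic gradient matches the full-data objective in expectation
n_obs = 5000
minibatch_term = n_obs * torch.mean(reconst_loss)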
def loss(
    self,
    tensors,
    inference_outputs,
    generative_outputs,
    n_obs: float = 1.0,
):
    # generative_outputs is a dict of the return value from `generative(...)`
    # assume that `n_obs` is the number of training data points
    p_x_c = generative_outputs["p_x_c"]
    gamma = generative_outputs["gamma"]

    # compute Q
    # take mean over cells and multiply by n_obs (instead of summing over n)
    q_per_cell = torch.sum(gamma * -p_x_c, 1)

    # third term is log prob of prior terms in Q
    theta_log = F.log_softmax(self.theta_logit, dim=-1)
    theta_log_prior = Dirichlet(self.dirichlet_concentration)
    theta_log_prob = -theta_log_prior.log_prob(
        torch.exp(theta_log) + THETA_LOWER_BOUND
    )
    prior_log_prob = theta_log_prob
    delta_log_prior = Normal(
        self.delta_log_mean, self.delta_log_log_scale.exp().sqrt()
    )
    delta_log_prob = torch.masked_select(
        delta_log_prior.log_prob(self.delta_log), (self.rho > 0)
    )
    prior_log_prob += -torch.sum(delta_log_prob)

    loss = (torch.mean(q_per_cell) * n_obs + prior_log_prob) / n_obs

    return LossRecorder(
        loss, q_per_cell, torch.zeros_like(q_per_cell), prior_log_prob
    )
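The prior term on the mixing proportions can be checked in isolation. The snippet below is a minimal sketch with a hypothetical five-type model; the concentration value and `THETA_LOWER_BOUND` (a small stabilizing constant) are assumptions, not values taken from the module.

import torch
import torch.nn.functional as F
from torch.distributions import Dirichlet

THETA_LOWER_BOUND = 1e-20                          # small constant for numerical stability
theta_logit = torch.zeros(5, requires_grad=True)   # unconstrained logits over 5 cell types
dirichlet_concentration = torch.full((5,), 1.0)    # arbitrary prior concentration

# mirror the prior term above: evaluate the Dirichlet prior at softmax(theta_logit)
theta_log = F.log_softmax(theta_logit, dim=-1)
theta_log_prob = -Dirichlet(dirichlet_concentration).log_prob(
    torch.exp(theta_log) + THETA_LOWER_BOUND
)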
def loss(
    self,
    tensors,
    inference_outputs,
    generative_outputs,
    kl_weight: float = 1.0,
    n_obs: float = 1.0,
) -> LossRecorder:
    # Parameters for z latent distribution
    qz_m = inference_outputs["qz_m"]
    qz_v = inference_outputs["qz_v"]
    ql_m = inference_outputs["ql_m"]
    ql_v = inference_outputs["ql_v"]
    px_rate = generative_outputs["px_rate"]
    px_r = generative_outputs["px_r"]
    px_dropout = generative_outputs["px_dropout"]
    bernoulli_params = generative_outputs["bernoulli_params"]
    x = tensors[_CONSTANTS.X_KEY]
    local_l_mean = tensors[_CONSTANTS.LOCAL_L_MEAN_KEY]
    local_l_var = tensors[_CONSTANTS.LOCAL_L_VAR_KEY]

    # KL divergences wrt z_n, l_n
    mean = torch.zeros_like(qz_m)
    scale = torch.ones_like(qz_v)
    kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(dim=1)
    kl_divergence_l = kl(
        Normal(ql_m, torch.sqrt(ql_v)),
        Normal(local_l_mean, torch.sqrt(local_l_var)),
    ).sum(dim=1)

    # KL divergence wrt Bernoulli parameters
    kl_divergence_bernoulli = self.compute_global_kl_divergence()

    # Reconstruction loss
    reconst_loss = self.get_reconstruction_loss(
        x, px_rate, px_r, px_dropout, bernoulli_params
    )

    kl_global = kl_divergence_bernoulli
    kl_local_for_warmup = kl_divergence_l
    kl_local_no_warmup = kl_divergence_z

    weighted_kl_local = kl_weight * kl_local_for_warmup + kl_local_no_warmup

    loss = n_obs * torch.mean(reconst_loss + weighted_kl_local) + kl_global

    kl_local = dict(kl_divergence_l=kl_divergence_l, kl_divergence_z=kl_divergence_z)
    return LossRecorder(loss, reconst_loss, kl_local, kl_global)
def loss(
    self,
    tensors,
    inference_outputs,
    generative_outputs,
    kl_weight: float = 1.0,
):
    x = tensors[_CONSTANTS.X_KEY]
    local_l_mean = tensors[_CONSTANTS.LOCAL_L_MEAN_KEY]
    local_l_var = tensors[_CONSTANTS.LOCAL_L_VAR_KEY]
    qz_m = inference_outputs["qz_m"]
    qz_v = inference_outputs["qz_v"]
    ql_m = inference_outputs["ql_m"]
    ql_v = inference_outputs["ql_v"]
    px_rate = generative_outputs["px_rate"]
    px_r = generative_outputs["px_r"]
    px_dropout = generative_outputs["px_dropout"]

    mean = torch.zeros_like(qz_m)
    scale = torch.ones_like(qz_v)

    kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(dim=1)

    if not self.use_observed_lib_size:
        kl_divergence_l = kl(
            Normal(ql_m, torch.sqrt(ql_v)),
            Normal(local_l_mean, torch.sqrt(local_l_var)),
        ).sum(dim=1)
    else:
        kl_divergence_l = 0.0

    reconst_loss = self.get_reconstruction_loss(x, px_rate, px_r, px_dropout)

    kl_local_for_warmup = kl_divergence_l
    kl_local_no_warmup = kl_divergence_z

    weighted_kl_local = kl_weight * kl_local_for_warmup + kl_local_no_warmup

    loss = torch.mean(reconst_loss + weighted_kl_local)

    kl_local = dict(
        kl_divergence_l=kl_divergence_l, kl_divergence_z=kl_divergence_z
    )
    kl_global = 0.0
    return LossRecorder(loss, reconst_loss, kl_local, kl_global)
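The two KL terms and the warm-up weighting reduce to the following minimal sketch on toy posterior parameters. The observed-library-size branch is assumed here, so the library KL is zero; shapes and values are made up.

import torch
from torch.distributions import Normal
from torch.distributions import kl_divergence as kl

# toy variational posterior over z: 8 cells, 10 latent dimensions
qz_m = torch.zeros(8, 10)
qz_v = torch.full((8, 10), 0.5)

# KL(q(z|x) || N(0, I)), summed over latent dimensions as above
kl_divergence_z = kl(
    Normal(qz_m, torch.sqrt(qz_v)),
    Normal(torch.zeros_like(qz_m), torch.ones_like(qz_v)),
).sum(dim=1)

reconst_loss = torch.rand(8) * 100.0     # stand-in for get_reconstruction_loss(...)
kl_divergence_l = 0.0                    # use_observed_lib_size assumed True
kl_weight = 0.25                         # annealed from 0 to 1 during warm-up

# only the library-size KL is annealed; the latent-space KL is not
weighted_kl_local = kl_weight * kl_divergence_l + kl_divergence_z
loss = torch.mean(reconst_loss + weighted_kl_local)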
def loss(
    self,
    tensors,
    inference_outputs,
    generative_outputs,
    kl_weight: float = 1.0,
):
    x = tensors[_CONSTANTS.X_KEY]
    px_rate = generative_outputs["px_rate"]
    px_o = generative_outputs["px_o"]
    scaling_factor = generative_outputs["scaling_factor"]

    reconst_loss = -NegativeBinomial(px_rate, logits=px_o).log_prob(x).sum(-1)
    loss = torch.mean(scaling_factor * reconst_loss)

    return LossRecorder(loss, reconst_loss, torch.zeros((1,)), 0.0)
def loss(
    self,
    tensors,
    inference_outputs,
    generative_outputs,
    kl_weight: float = 1.0,
):
    x = tensors[_CONSTANTS.X_KEY]
    qz_m = inference_outputs["qz_m"]
    qz_v = inference_outputs["qz_v"]
    d = inference_outputs["d"]
    p = generative_outputs["p"]

    kld = kl_divergence(
        Normal(qz_m, torch.sqrt(qz_v)),
        Normal(0, 1),
    ).sum(dim=1)

    f = torch.sigmoid(self.region_factors) if self.region_factors is not None else 1
    rl = self.get_reconstruction_loss(p, d, f, x)

    loss = (rl.sum() + kld * kl_weight).sum()

    return LossRecorder(loss, rl, kld, kl_global=0.0)
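The body of `get_reconstruction_loss` is not shown here. As an illustration only, a Bernoulli reconstruction of this shape could look like the sketch below, where accessibility probabilities are scaled by a per-cell depth factor and per-region factors before a binary cross-entropy against binarized counts; this is an assumption, not necessarily the module's exact implementation.

import torch
import torch.nn.functional as F

# toy tensors: 8 cells x 200 regions (all shapes and values are made up)
p = torch.rand(8, 200)                    # accessibility probabilities
d = torch.rand(8, 1)                      # per-cell depth factor in (0, 1)
f = torch.rand(1, 200)                    # per-region factors in (0, 1)
x = (torch.rand(8, 200) > 0.7).float()    # binarized peak counts

# per-cell reconstruction loss: BCE of the scaled probabilities against x
rl = F.binary_cross_entropy(p * d * f, x, reduction="none").sum(dim=-1)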
def loss(
    self,
    tensors,
    inference_outputs,
    generative_outputs,
    mode: Optional[int] = None,
    kl_weight=1.0,
) -> LossRecorder:
    """
    Return the reconstruction loss and the Kullback-Leibler divergences.

    Parameters
    ----------
    x
        tensor of values with shape ``(batch_size, n_input)``
        or ``(batch_size, n_input_fish)`` depending on the mode
    local_l_mean
        tensor of means of the prior distribution of latent variable l
        with shape ``(batch_size, 1)``
    local_l_var
        tensor of variances of the prior distribution of latent variable l
        with shape ``(batch_size, 1)``
    batch_index
        array that indicates which batch the cells belong to with shape ``batch_size``
    y
        tensor of cell-type labels with shape ``(batch_size, n_labels)``
    mode
        indicates which head/tail to use in the joint network

    Returns
    -------
    the reconstruction loss and the Kullback-Leibler divergences
    """
    if mode is None:
        if len(self.n_input_list) == 1:
            mode = 0
        else:
            raise Exception("Must provide a mode")

    x = tensors[_CONSTANTS.X_KEY]
    local_l_mean = tensors[_CONSTANTS.LOCAL_L_MEAN_KEY]
    local_l_var = tensors[_CONSTANTS.LOCAL_L_VAR_KEY]

    qz_m = inference_outputs["qz_m"]
    qz_v = inference_outputs["qz_v"]
    ql_m = inference_outputs["ql_m"]
    ql_v = inference_outputs["ql_v"]
    px_rate = generative_outputs["px_rate"]
    px_r = generative_outputs["px_r"]
    px_dropout = generative_outputs["px_dropout"]

    # mask loss to observed genes
    mapping_indices = self.indices_mappings[mode]
    reconstruction_loss = self.reconstruction_loss(
        x,
        px_rate[:, mapping_indices],
        px_r[:, mapping_indices],
        px_dropout[:, mapping_indices],
        mode,
    )

    # KL divergence
    mean = torch.zeros_like(qz_m)
    scale = torch.ones_like(qz_v)
    kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(dim=1)

    if self.model_library_bools[mode]:
        kl_divergence_l = kl(
            Normal(ql_m, torch.sqrt(ql_v)),
            Normal(local_l_mean, torch.sqrt(local_l_var)),
        ).sum(dim=1)
    else:
        kl_divergence_l = torch.zeros_like(kl_divergence_z)

    kl_local = kl_divergence_l + kl_divergence_z
    kl_global = 0.0

    loss = torch.mean(reconstruction_loss + kl_weight * kl_local) * x.size(0)

    return LossRecorder(loss, reconstruction_loss, kl_local, kl_global)
def loss(
    self,
    tensors,
    inference_outputs,
    generative_outputs,
    pro_recons_weight=1.0,  # double check these defaults
    kl_weight=1.0,
) -> LossRecorder:
    """
    Return the reconstruction loss and the Kullback-Leibler divergences.

    Parameters
    ----------
    x
        tensor of values with shape ``(batch_size, n_input_genes)``
    y
        tensor of values with shape ``(batch_size, n_input_proteins)``
    local_l_mean_gene
        tensor of means of the prior distribution of latent variable l
        with shape ``(batch_size, 1)``
    local_l_var_gene
        tensor of variances of the prior distribution of latent variable l
        with shape ``(batch_size, 1)``
    batch_index
        array that indicates which batch the cells belong to with shape ``batch_size``
    label
        tensor of cell-type labels with shape ``(batch_size, n_labels)``

    Returns
    -------
    the reconstruction loss and the Kullback-Leibler divergences
    """
    qz_m = inference_outputs["qz_m"]
    qz_v = inference_outputs["qz_v"]
    ql_m = inference_outputs["ql_m"]
    ql_v = inference_outputs["ql_v"]
    px_ = generative_outputs["px_"]
    py_ = generative_outputs["py_"]

    x = tensors[_CONSTANTS.X_KEY]
    batch_index = tensors[_CONSTANTS.BATCH_KEY]
    local_l_mean_gene = tensors[_CONSTANTS.LOCAL_L_MEAN_KEY]
    local_l_var_gene = tensors[_CONSTANTS.LOCAL_L_VAR_KEY]
    y = tensors[_CONSTANTS.PROTEIN_EXP_KEY]

    if self.protein_batch_mask is not None:
        pro_batch_mask_minibatch = torch.zeros_like(y)
        for b in torch.unique(batch_index):
            b_indices = (batch_index == b).reshape(-1)
            pro_batch_mask_minibatch[b_indices] = torch.tensor(
                self.protein_batch_mask[b.item()].astype(np.float32),
                device=y.device,
            )
    else:
        pro_batch_mask_minibatch = None

    reconst_loss_gene, reconst_loss_protein = self.get_reconstruction_loss(
        x, y, px_, py_, pro_batch_mask_minibatch
    )

    # KL divergence
    kl_div_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(0, 1)).sum(dim=1)
    if not self.use_observed_lib_size:
        kl_div_l_gene = kl(
            Normal(ql_m, torch.sqrt(ql_v)),
            Normal(local_l_mean_gene, torch.sqrt(local_l_var_gene)),
        ).sum(dim=1)
    else:
        kl_div_l_gene = 0.0

    kl_div_back_pro_full = kl(
        Normal(py_["back_alpha"], py_["back_beta"]), self.back_mean_prior
    )
    if pro_batch_mask_minibatch is not None:
        kl_div_back_pro = (pro_batch_mask_minibatch * kl_div_back_pro_full).sum(dim=1)
    else:
        kl_div_back_pro = kl_div_back_pro_full.sum(dim=1)

    loss = torch.mean(
        reconst_loss_gene
        + pro_recons_weight * reconst_loss_protein
        + kl_weight * kl_div_z
        + kl_div_l_gene
        + kl_weight * kl_div_back_pro
    )

    reconst_losses = dict(
        reconst_loss_gene=reconst_loss_gene,
        reconst_loss_protein=reconst_loss_protein,
    )
    kl_local = dict(
        kl_div_z=kl_div_z,
        kl_div_l_gene=kl_div_l_gene,
        kl_div_back_pro=kl_div_back_pro,
    )

    return LossRecorder(loss, reconst_losses, kl_local, kl_global=0.0)
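The per-batch protein mask logic above can be traced on a toy example. The snippet below is a sketch with two hypothetical batches and three proteins, where `protein_batch_mask` maps a batch id to a boolean array of proteins measured in that batch; all names and values outside that mapping are placeholders.

import numpy as np
import torch

# hypothetical setup: batch 0 measured proteins {0, 1}, batch 1 measured {0, 2}
protein_batch_mask = {0: np.array([1, 1, 0]), 1: np.array([1, 0, 1])}
batch_index = torch.tensor([[0], [1], [0]])
y = torch.rand(3, 3)                                   # toy protein counts

pro_batch_mask_minibatch = torch.zeros_like(y)
for b in torch.unique(batch_index):
    b_indices = (batch_index == b).reshape(-1)
    pro_batch_mask_minibatch[b_indices] = torch.tensor(
        protein_batch_mask[b.item()].astype(np.float32), device=y.device
    )

# masked sum, as applied to the background-protein KL term above
kl_div_back_pro_full = torch.rand(3, 3)                # stand-in for the full KL
kl_div_back_pro = (pro_batch_mask_minibatch * kl_div_back_pro_full).sum(dim=1)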
def loss(
    self,
    tensors,
    inference_outputs,
    generative_outputs,
    feed_labels=False,
    kl_weight=1.0,
    labelled_tensors=None,
    classification_ratio=None,
):
    px_r = generative_outputs["px_r"]
    px_rate = generative_outputs["px_rate"]
    px_dropout = generative_outputs["px_dropout"]
    qz1_m = inference_outputs["qz_m"]
    qz1_v = inference_outputs["qz_v"]
    z1 = inference_outputs["z"]
    ql_m = inference_outputs["ql_m"]
    ql_v = inference_outputs["ql_v"]
    x = tensors[_CONSTANTS.X_KEY]
    local_l_mean = tensors[_CONSTANTS.LOCAL_L_MEAN_KEY]
    local_l_var = tensors[_CONSTANTS.LOCAL_L_VAR_KEY]

    if feed_labels:
        y = tensors[_CONSTANTS.LABELS_KEY]
    else:
        y = None
    is_labelled = y is not None

    # Enumerate choices of label
    ys, z1s = broadcast_labels(y, z1, n_broadcast=self.n_labels)
    qz2_m, qz2_v, z2 = self.encoder_z2_z1(z1s, ys)
    pz1_m, pz1_v = self.decoder_z1_z2(z2, ys)

    reconst_loss = self.get_reconstruction_loss(x, px_rate, px_r, px_dropout)

    # KL divergence
    mean = torch.zeros_like(qz2_m)
    scale = torch.ones_like(qz2_v)
    kl_divergence_z2 = kl(Normal(qz2_m, torch.sqrt(qz2_v)), Normal(mean, scale)).sum(dim=1)

    loss_z1_unweight = -Normal(pz1_m, torch.sqrt(pz1_v)).log_prob(z1s).sum(dim=-1)
    loss_z1_weight = Normal(qz1_m, torch.sqrt(qz1_v)).log_prob(z1).sum(dim=-1)

    if not self.use_observed_lib_size:
        kl_divergence_l = kl(
            Normal(ql_m, torch.sqrt(ql_v)),
            Normal(local_l_mean, torch.sqrt(local_l_var)),
        ).sum(dim=1)
    else:
        kl_divergence_l = 0.0

    if is_labelled:
        loss = reconst_loss + loss_z1_weight + loss_z1_unweight
        kl_locals = {
            "kl_divergence_z2": kl_divergence_z2,
            "kl_divergence_l": kl_divergence_l,
        }
        if labelled_tensors is not None:
            loss += self.classification_loss(labelled_tensors) * classification_ratio
        return LossRecorder(loss, reconst_loss, kl_locals, kl_global=0.0)

    probs = self.classifier(z1)
    reconst_loss += loss_z1_weight + (
        (loss_z1_unweight).view(self.n_labels, -1).t() * probs
    ).sum(dim=1)

    kl_divergence = (kl_divergence_z2.view(self.n_labels, -1).t() * probs).sum(dim=1)
    kl_divergence += kl(
        Categorical(probs=probs),
        Categorical(probs=self.y_prior.repeat(probs.size(0), 1)),
    )
    kl_divergence += kl_divergence_l

    loss = torch.mean(reconst_loss + kl_divergence * kl_weight)

    if labelled_tensors is not None:
        loss += self.classification_loss(labelled_tensors) * classification_ratio

    return LossRecorder(loss, reconst_loss, kl_divergence, kl_global=0.0)
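In the unlabelled branch, per-label terms computed on the broadcast latent samples are reshaped and averaged under the classifier probabilities, which amounts to taking an expectation over the unknown label. The sketch below only illustrates that reshaping on toy tensors; the label-major layout is assumed to match what `broadcast_labels` produces.

import torch

n_labels, batch_size = 4, 8

# one loss value per (label, cell) pair, as produced on the broadcast samples
loss_z1_unweight = torch.rand(n_labels * batch_size)
# classifier probabilities q(y | z1), one row per cell
probs = torch.softmax(torch.rand(batch_size, n_labels), dim=-1)

# expectation over y: reshape to (n_labels, batch), transpose, weight, and sum
expected_loss = (loss_z1_unweight.view(n_labels, -1).t() * probs).sum(dim=1)  # (batch_size,)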