def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None): r""" Returns the reconstruction loss and the Kullback divergences :param x: tensor of values with shape (batch_size, n_input) :param local_l_mean: tensor of means of the prior distribution of latent variable l with shape (batch_size, 1) :param local_l_var: tensor of variancess of the prior distribution of latent variable l with shape (batch_size, 1) :param batch_index: array that indicates which batch the cells belong to with shape ``batch_size`` :param y: tensor of cell-types labels with shape (batch_size, n_labels) :return: the reconstruction loss and the Kullback divergences :rtype: 2-tuple of :py:class:`torch.FloatTensor` """ # assert self.trained_decoder, "If you train the encoder alone please use the `ratio_loss`" \ # "In `forward`, the KL terms are wrong" px_rate, qz_m, qz_v, z, ql_m, ql_v, library = self.inference( x, batch_index, y) # KL Divergence mean, scale = self.get_prior_params(device=qz_m.device) kl_divergence_z = kl(self.z_encoder.distrib(qz_m, qz_v), self.z_encoder.distrib(mean, scale)) if len(kl_divergence_z.size()) == 2: kl_divergence_z = kl_divergence_z.sum(dim=1) kl_divergence_l = kl( Normal(ql_m, torch.sqrt(ql_v)), Normal(local_l_mean, torch.sqrt(local_l_var)), ).sum(dim=1) kl_divergence = kl_divergence_z reconst_loss = self.get_reconstruction_loss(x, px_rate) return reconst_loss + kl_divergence_l, kl_divergence
def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None):
    # Parameters for z latent distribution
    x_ = x
    if self.log_variational:
        x_ = torch.log(1 + x_)

    # Sampling
    qz_m, qz_v, z = self.z_encoder(x_)
    ql_m, ql_v, library = self.l_encoder(x_)

    px_scale, px_r, px_rate, px_dropout = self.decoder(
        self.dispersion, z, library, batch_index
    )

    reconst_loss = self._reconstruction_loss(
        x, px_rate, px_r, px_dropout, batch_index, y
    )

    # KL Divergence
    mean = torch.zeros_like(qz_m)
    scale = torch.ones_like(qz_v)
    kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(dim=1)
    kl_divergence_l = kl(
        Normal(ql_m, torch.sqrt(ql_v)),
        Normal(local_l_mean, torch.sqrt(local_l_var)),
    ).sum(dim=1)
    kl_divergence = kl_divergence_z + kl_divergence_l

    return reconst_loss, kl_divergence
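# --- Illustrative aside (not part of the model code above) ---
# The kl(Normal(qz_m, sqrt(qz_v)), Normal(0, 1)) pattern used throughout relies on
# PyTorch's registered closed-form KL between Gaussians. A minimal sketch checking it
# against the analytic formula KL(N(m, s^2) || N(0, 1)) = 0.5 * (s^2 + m^2 - 1 - log s^2),
# with toy tensors standing in for qz_m / qz_v:
import torch
from torch.distributions import Normal
from torch.distributions import kl_divergence as kl

qz_m = torch.randn(8, 10)        # toy posterior means
qz_v = torch.rand(8, 10) + 0.1   # toy posterior variances
analytic = 0.5 * (qz_v + qz_m ** 2 - 1.0 - torch.log(qz_v))
closed_form = kl(
    Normal(qz_m, torch.sqrt(qz_v)),
    Normal(torch.zeros_like(qz_m), torch.ones_like(qz_v)),
)
assert torch.allclose(analytic, closed_form, atol=1e-5)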
def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None): r""" Returns the reconstruction loss and the Kullback divergences :param x: tensor of values with shape (batch_size, n_input) :param local_l_mean: tensor of means of the prior distribution of latent variable l with shape (batch_size, 1) :param local_l_var: tensor of variancess of the prior distribution of latent variable l with shape (batch_size, 1) :param batch_index: array that indicates which batch the cells belong to with shape ``batch_size`` :param y: tensor of cell-types labels with shape (batch_size, n_labels) :return: the reconstruction loss and the Kullback divergences :rtype: 2-tuple of :py:class:`torch.FloatTensor` """ # Parameters for z latent distribution px_scale, px_r, px_rate, px_dropout, qz_m, qz_v, z, ql_m, ql_v, library = self.inference(x, batch_index, y) reconst_loss = self._reconstruction_loss(x, px_rate, px_r, px_dropout) # KL Divergence mean = torch.zeros_like(qz_m) scale = torch.ones_like(qz_v) kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(dim=1) kl_divergence_l = kl(Normal(ql_m, torch.sqrt(ql_v)), Normal(local_l_mean, torch.sqrt(local_l_var))).sum(dim=1) kl_divergence = kl_divergence_z return reconst_loss + kl_divergence_l, kl_divergence
def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None):
    # Prepare for sampling
    x_ = torch.log(1 + x)
    ql_m, ql_v, library = self.l_encoder(x_)

    # Enumerate choices of label
    ys, xs, library_s, batch_index_s = broadcast_labels(
        y, x, library, batch_index, n_broadcast=self.n_labels
    )

    xs_ = xs  # fall back to raw counts when log_variational is disabled
    if self.log_variational:
        xs_ = torch.log(1 + xs)

    # Sampling
    qz_m, qz_v, zs = self.z_encoder(xs_, batch_index_s, ys)
    px_scale, px_r, px_rate, px_dropout = self.decoder(
        self.dispersion, zs, library_s, batch_index_s, ys
    )

    reconst_loss = self._reconstruction_loss(
        xs, px_rate, px_r, px_dropout, batch_index_s, ys
    )

    # KL Divergence
    mean = torch.zeros_like(qz_m)
    scale = torch.ones_like(qz_v)
    kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(dim=1)
    kl_divergence_l = kl(
        Normal(ql_m, torch.sqrt(ql_v)),
        Normal(local_l_mean, torch.sqrt(local_l_var)),
    ).sum(dim=1)

    return reconst_loss, kl_divergence_z + kl_divergence_l
def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None):
    is_labelled = False if y is None else True

    outputs = self.inference(x, batch_index, y)
    px_r = outputs["px_r"]
    px_rate = outputs["px_rate"]
    px_dropout = outputs["px_dropout"]
    qz1_m = outputs["qz_m"]
    qz1_v = outputs["qz_v"]
    z1 = outputs["z"]
    ql_m = outputs["ql_m"]
    ql_v = outputs["ql_v"]

    # Enumerate choices of label
    ys, z1s = broadcast_labels(y, z1, n_broadcast=self.n_labels)
    qz2_m, qz2_v, z2 = self.encoder_z2_z1(z1s, ys)
    pz1_m, pz1_v = self.decoder_z1_z2(z2, ys)

    reconst_loss = self.get_reconstruction_loss(x, px_rate, px_r, px_dropout)

    # KL Divergence
    mean = torch.zeros_like(qz2_m)
    scale = torch.ones_like(qz2_v)
    kl_divergence_z2 = kl(Normal(qz2_m, torch.sqrt(qz2_v)), Normal(mean, scale)).sum(dim=1)
    loss_z1_unweight = -Normal(pz1_m, torch.sqrt(pz1_v)).log_prob(z1s).sum(dim=-1)
    loss_z1_weight = Normal(qz1_m, torch.sqrt(qz1_v)).log_prob(z1).sum(dim=-1)

    if not self.use_observed_lib_size:
        kl_divergence_l = kl(
            Normal(ql_m, torch.sqrt(ql_v)),
            Normal(local_l_mean, torch.sqrt(local_l_var)),
        ).sum(dim=1)
    else:
        kl_divergence_l = 0.0

    if is_labelled:
        return (
            reconst_loss + loss_z1_weight + loss_z1_unweight,
            kl_divergence_z2 + kl_divergence_l,
            0.0,
        )

    probs = self.classifier(z1)
    reconst_loss += loss_z1_weight + (
        (loss_z1_unweight).view(self.n_labels, -1).t() * probs
    ).sum(dim=1)
    kl_divergence = (kl_divergence_z2.view(self.n_labels, -1).t() * probs).sum(dim=1)
    kl_divergence += kl(
        Categorical(probs=probs),
        Categorical(probs=self.y_prior.repeat(probs.size(0), 1)),
    )
    kl_divergence += kl_divergence_l

    return reconst_loss, kl_divergence, 0.0
def forward(
    self,
    x: torch.Tensor,
    local_l_mean: torch.Tensor,
    local_l_var: torch.Tensor,
    batch_index: Optional[torch.Tensor] = None,
    y: Optional[torch.Tensor] = None,
    mode: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the reconstruction loss and the Kullback divergences

    :param x: tensor of values with shape ``(batch_size, n_input)``
        or ``(batch_size, n_input_fish)`` depending on the mode
    :param local_l_mean: tensor of means of the prior distribution of latent variable l
        with shape (batch_size, 1)
    :param local_l_var: tensor of variances of the prior distribution of latent variable l
        with shape (batch_size, 1)
    :param batch_index: array that indicates which batch the cells belong to with shape ``batch_size``
    :param y: tensor of cell-types labels with shape (batch_size, n_labels)
    :param mode: indicates which head/tail to use in the joint network
    :return: the reconstruction loss and the Kullback divergences
    """
    if mode is None:
        if len(self.n_input_list) == 1:
            mode = 0
        else:
            raise Exception("Must provide a mode")

    qz_m, qz_v, z, ql_m, ql_v, library = self.encode(x, mode)
    px_scale, px_r, px_rate, px_dropout = self.decode(
        z, mode, library, batch_index, y
    )

    # mask loss to observed genes
    mapping_indices = self.indices_mappings[mode]
    reconstruction_loss = self.reconstruction_loss(
        x,
        px_rate[:, mapping_indices],
        px_r[:, mapping_indices],
        px_dropout[:, mapping_indices],
        mode,
    )

    # KL Divergence
    mean = torch.zeros_like(qz_m)
    scale = torch.ones_like(qz_v)
    kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(dim=1)
    if self.model_library_bools[mode]:
        kl_divergence_l = kl(
            Normal(ql_m, torch.sqrt(ql_v)),
            Normal(local_l_mean, torch.sqrt(local_l_var)),
        ).sum(dim=1)
    else:
        kl_divergence_l = torch.zeros_like(kl_divergence_z)

    return reconstruction_loss, kl_divergence_l + kl_divergence_z, 0.0
def forward(
    self, x, local_l_mean, local_l_var, batch_index=None, y=None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Returns the reconstruction loss and the KL divergences.

    Parameters
    ----------
    x
        tensor of values with shape (batch_size, n_input)
    local_l_mean
        tensor of means of the prior distribution of latent variable l
        with shape (batch_size, 1)
    local_l_var
        tensor of variances of the prior distribution of latent variable l
        with shape (batch_size, 1)
    batch_index
        array that indicates which batch the cells belong to with shape ``batch_size``
        (Default value = None)
    y
        tensor of cell-types labels with shape (batch_size, n_labels)
        (Default value = None)

    Returns
    -------
    type
        the reconstruction loss and the Kullback divergences
    """
    # Parameters for z latent distribution
    outputs = self.inference(x, batch_index, y)
    qz_m = outputs["qz_m"]
    qz_v = outputs["qz_v"]
    ql_m = outputs["ql_m"]
    ql_v = outputs["ql_v"]
    px_rate = outputs["px_rate"]
    px_r = outputs["px_r"]
    px_dropout = outputs["px_dropout"]

    # KL Divergence
    mean = torch.zeros_like(qz_m)
    scale = torch.ones_like(qz_v)
    kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(dim=1)
    if not self.use_observed_lib_size:
        kl_divergence_l = kl(
            Normal(ql_m, torch.sqrt(ql_v)),
            Normal(local_l_mean, torch.sqrt(local_l_var)),
        ).sum(dim=1)
    else:
        kl_divergence_l = 0.0
    kl_divergence = kl_divergence_z

    reconst_loss = self.get_reconstruction_loss(x, px_rate, px_r, px_dropout)

    return reconst_loss + kl_divergence_l, kl_divergence, 0.0
def forward(self, X1, X2, local_l_mean, local_l_var, local_l_mean1, local_l_var1):
    result = self.inference(X1, X2)
    disper_x = result["disper_x"]
    recon_x1 = result["recon_x1"]
    dropout_rate = result["dropout_rate"]
    disper_x2 = result["disper_x2"]
    recon_x_2 = result["recon_x_2"]
    dropout_rate_2 = result["dropout_rate_2"]

    if X1 is not None:
        mean_l = result["mean_l"]
        logvar_l = result["logvar_l"]
        kl_divergence_l = kl(
            Normal(mean_l, logvar_l),
            Normal(local_l_mean, torch.sqrt(local_l_var)),
        ).sum(dim=1)
    else:
        kl_divergence_l = torch.tensor(0.0)

    if X2 is not None:
        if self.Type == 'ZINB':
            mean_l2 = result["mean_l2"]
            logvar_l2 = result["library2"]
            kl_divergence_l2 = kl(
                Normal(mean_l2, logvar_l2),
                Normal(local_l_mean1, torch.sqrt(local_l_var1)),
            ).sum(dim=1)
        else:
            kl_divergence_l2 = torch.tensor(0.0)
    else:
        kl_divergence_l2 = torch.tensor(0.0)

    mean_z = result["mean_z"]
    logvar_z = result["logvar_z"]
    latent_z = result["latent_z"]

    if self.penality == "GMM":
        gamma, mu_c, var_c, pi = self.get_gamma(latent_z)  # , self.n_centroids, c_params)
        kl_divergence_z = GMM_loss(gamma, (mu_c, var_c, pi), (mean_z, logvar_z))
    else:
        mean = torch.zeros_like(mean_z)
        scale = torch.ones_like(logvar_z)
        kl_divergence_z = kl(Normal(mean_z, logvar_z), Normal(mean, scale)).sum(dim=1)

    loss1, loss2 = get_both_recon_loss(
        X1,
        recon_x1,
        disper_x,
        dropout_rate,
        X2,
        recon_x_2,
        disper_x2,
        dropout_rate_2,
        "ZINB",
        self.Type,
    )

    return loss1, loss2, kl_divergence_l, kl_divergence_l2, kl_divergence_z
def loss(
    self,
    tensors,
    inference_outputs,
    generative_outputs,
    kl_weight: float = 1.0,
    n_obs: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # Parameters for z latent distribution
    qz_m = inference_outputs["qz_m"]
    qz_v = inference_outputs["qz_v"]
    px_rate = generative_outputs["px_rate"]
    px_r = generative_outputs["px_r"]
    px_dropout = generative_outputs["px_dropout"]
    bernoulli_params = generative_outputs["bernoulli_params"]
    x = tensors[_CONSTANTS.X_KEY]
    batch_index = tensors[_CONSTANTS.BATCH_KEY]

    # KL divergences wrt z_n, l_n
    mean = torch.zeros_like(qz_m)
    scale = torch.ones_like(qz_v)
    kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(dim=1)
    if not self.use_observed_lib_size:
        ql_m = inference_outputs["ql_m"]
        ql_v = inference_outputs["ql_v"]
        (
            local_library_log_means,
            local_library_log_vars,
        ) = self._compute_local_library_params(batch_index)
        kl_divergence_l = kl(
            Normal(ql_m, torch.sqrt(ql_v)),
            Normal(local_library_log_means, torch.sqrt(local_library_log_vars)),
        ).sum(dim=1)
    else:
        kl_divergence_l = 0.0

    # KL divergence wrt Bernoulli parameters
    kl_divergence_bernoulli = self.compute_global_kl_divergence()

    # Reconstruction loss
    reconst_loss = self.get_reconstruction_loss(
        x, px_rate, px_r, px_dropout, bernoulli_params
    )

    kl_global = kl_divergence_bernoulli
    kl_local_for_warmup = kl_divergence_z
    kl_local_no_warmup = kl_divergence_l
    weighted_kl_local = kl_weight * kl_local_for_warmup + kl_local_no_warmup

    loss = n_obs * torch.mean(reconst_loss + weighted_kl_local) + kl_global

    kl_local = dict(kl_divergence_l=kl_divergence_l, kl_divergence_z=kl_divergence_z)
    return LossRecorder(loss, reconst_loss, kl_local, kl_global)
def forward(
    self,
    x: torch.Tensor,
    y: torch.Tensor,
    local_l_mean_gene: torch.Tensor,
    local_l_var_gene: torch.Tensor,
    batch_index: Optional[torch.Tensor] = None,
    label: Optional[torch.Tensor] = None,
):
    r"""Returns the reconstruction loss and the Kullback divergences

    :param x: tensor of values with shape (batch_size, n_input_genes)
    :param y: tensor of values with shape (batch_size, n_input_proteins)
    :param local_l_mean_gene: tensor of means of the prior distribution of latent variable l
        with shape (batch_size, 1)
    :param local_l_var_gene: tensor of variances of the prior distribution of latent variable l
        with shape (batch_size, 1)
    :param batch_index: array that indicates which batch the cells belong to with shape ``batch_size``
    :param label: tensor of cell-types labels with shape (batch_size, n_labels)
    :return: the reconstruction loss and the Kullback divergences
    :rtype: 4-tuple of :py:class:`torch.FloatTensor`
    """
    # Parameters for z latent distribution
    outputs = self.inference(x, y, batch_index, label)
    qz_m = outputs["qz_m"]
    qz_v = outputs["qz_v"]
    ql_m = outputs["ql_m"]
    ql_v = outputs["ql_v"]
    px_ = outputs["px_"]
    py_ = outputs["py_"]

    reconst_loss_gene, reconst_loss_protein = self.get_reconstruction_loss(
        x, y, px_, py_
    )

    # KL Divergence
    kl_div_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(0, 1)).sum(dim=1)
    kl_div_l_gene = kl(
        Normal(ql_m, torch.sqrt(ql_v)),
        Normal(local_l_mean_gene, torch.sqrt(local_l_var_gene)),
    ).sum(dim=1)
    kl_div_back_pro = kl(
        Normal(py_["back_alpha"], py_["back_beta"]), self.back_mean_prior
    ).sum(dim=-1)

    return (
        reconst_loss_gene,
        reconst_loss_protein,
        kl_div_z,
        kl_div_l_gene,
        kl_div_back_pro,
    )
def forward(
    self,
    x: torch.Tensor,
    local_l_mean: torch.Tensor,
    local_l_var: torch.Tensor,
    batch_index: Optional[torch.Tensor] = None,
    y: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    r"""Returns the reconstruction loss and the Kullback divergences

    :param x: tensor of values with shape (batch_size, n_input)
    :param local_l_mean: tensor of means of the prior distribution of latent variable l
        with shape (batch_size, 1)
    :param local_l_var: tensor of variances of the prior distribution of latent variable l
        with shape (batch_size, 1)
    :param batch_index: array that indicates which batch the cells belong to with shape ``batch_size``
    :param y: tensor of cell-types labels with shape (batch_size, n_labels)
    :return: the reconstruction loss and the Kullback divergences
    :rtype: 2-tuple of :py:class:`torch.FloatTensor`
    """
    # Parameters for z latent distribution
    outputs = self.inference(x, batch_index, y)
    qz_m = outputs["qz_m"]
    qz_v = outputs["qz_v"]
    ql_m = outputs["ql_m"]
    ql_v = outputs["ql_v"]
    px_rate = outputs["px_rate"]
    px_r = outputs["px_r"]
    px_dropout = outputs["px_dropout"]
    bernoulli_params = outputs["bernoulli_params"]

    # KL divergences wrt z_n, l_n
    mean = torch.zeros_like(qz_m)
    scale = torch.ones_like(qz_v)
    kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(dim=1)
    kl_divergence_l = kl(
        Normal(ql_m, torch.sqrt(ql_v)),
        Normal(local_l_mean, torch.sqrt(local_l_var)),
    ).sum(dim=1)

    # KL divergence wrt Bernoulli parameters
    kl_divergence_bernoulli = self.compute_global_kl_divergence()

    # Reconstruction loss
    reconst_loss = self.get_reconstruction_loss(
        x, px_rate, px_r, px_dropout, bernoulli_params
    )

    return reconst_loss + kl_divergence_l, kl_divergence_z, kl_divergence_bernoulli
def loss(
    self,
    tensors,
    inference_outputs,
    generative_outputs,
    kl_weight: float = 1.0,
):
    x = tensors[REGISTRY_KEYS.X_KEY]
    batch_index = tensors[REGISTRY_KEYS.BATCH_KEY]

    qz_m = inference_outputs["qz_m"]
    qz_v = inference_outputs["qz_v"]
    px_rate = generative_outputs["px_rate"]
    px_r = generative_outputs["px_r"]
    px_dropout = generative_outputs["px_dropout"]

    mean = torch.zeros_like(qz_m)
    scale = torch.ones_like(qz_v)
    kl_divergence_z = kl(Normal(qz_m, qz_v.sqrt()), Normal(mean, scale)).sum(dim=1)

    if not self.use_observed_lib_size:
        ql_m = inference_outputs["ql_m"]
        ql_v = inference_outputs["ql_v"]
        (
            local_library_log_means,
            local_library_log_vars,
        ) = self._compute_local_library_params(batch_index)
        kl_divergence_l = kl(
            Normal(ql_m, ql_v.sqrt()),
            Normal(local_library_log_means, local_library_log_vars.sqrt()),
        ).sum(dim=1)
    else:
        kl_divergence_l = 0.0

    reconst_loss = self.get_reconstruction_loss(x, px_rate, px_r, px_dropout)

    kl_local_for_warmup = kl_divergence_z
    kl_local_no_warmup = kl_divergence_l
    weighted_kl_local = kl_weight * kl_local_for_warmup + kl_local_no_warmup

    loss = torch.mean(reconst_loss + weighted_kl_local)

    kl_local = dict(
        kl_divergence_l=kl_divergence_l, kl_divergence_z=kl_divergence_z
    )
    kl_global = torch.tensor(0.0)
    return LossRecorder(loss, reconst_loss, kl_local, kl_global)
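# --- Illustrative aside (not part of the model code above) ---
# A minimal sketch of how a kl_weight like the one consumed by the loss() methods
# above might be annealed during training ("KL warmup"); the linear schedule and the
# n_epochs_kl_warmup default are illustrative assumptions, not the library's API:
def kl_annealing_weight(epoch: int, n_epochs_kl_warmup: int = 400) -> float:
    """Linearly ramp the KL weight from 0 to 1 over the warmup epochs."""
    if n_epochs_kl_warmup <= 0:
        return 1.0
    return min(1.0, epoch / n_epochs_kl_warmup)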
def loss(
    self,
    tensors,
    inference_outputs,
    generative_outputs,
    kl_weight: float = 1.0,
):
    kl_weight = self.kl_factor * kl_weight

    x = tensors[_CONSTANTS.X_KEY]
    local_l_mean = tensors[_CONSTANTS.LOCAL_L_MEAN_KEY]
    local_l_var = tensors[_CONSTANTS.LOCAL_L_VAR_KEY]

    qz_m = inference_outputs["qz_m"]
    qz_v = inference_outputs["qz_v"]
    ql_m = inference_outputs["ql_m"]
    ql_v = inference_outputs["ql_v"]
    px_rate = generative_outputs["px_rate"]
    px_r = generative_outputs["px_r"]
    px_dropout = generative_outputs["px_dropout"]

    mean = torch.zeros_like(qz_m)
    scale = torch.ones_like(qz_v)
    kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(dim=1)
    kl_divergence_l = kl(
        Normal(ql_m, torch.sqrt(ql_v)),
        Normal(local_l_mean, torch.sqrt(local_l_var)),
    ).sum(dim=1)

    reconst_loss = (
        -ZeroInflatedNegativeBinomial(mu=px_rate, theta=px_r, zi_logits=px_dropout)
        .log_prob(x)
        .sum(dim=-1)
    )

    kl_local_for_warmup = kl_divergence_z
    kl_local_no_warmup = kl_divergence_l
    weighted_kl_local = kl_weight * kl_local_for_warmup + kl_local_no_warmup

    loss = torch.mean(reconst_loss + weighted_kl_local)

    kl_local = dict(
        kl_divergence_l=kl_divergence_l, kl_divergence_z=kl_divergence_z
    )
    kl_global = 0.0
    return LossRecorder(loss, reconst_loss, kl_local, kl_global)
def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None):  # same signature as loss
    # Parameters for z latent distribution
    x_ = x
    if self.log_variational:
        x_ = torch.log(1 + x_)

    # Sampling
    qz_m, qz_v, z = self.z_encoder(x_)
    ql_m, ql_v, library = self.l_encoder(x_)

    if self.dispersion == "gene-cell":
        px_scale, self.px_r, px_rate, px_dropout = self.decoder(
            self.dispersion, z, library, batch_index
        )
    else:  # self.dispersion == "gene", "gene-batch", "gene-label"
        px_scale, px_rate, px_dropout = self.decoder(
            self.dispersion, z, library, batch_index
        )

    if self.dispersion == "gene-label":
        px_r = F.linear(
            one_hot(y, self.n_labels), self.px_r
        )  # px_r gets transposed - last dimension is nb genes
    elif self.dispersion == "gene-batch":
        px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
    else:
        px_r = self.px_r

    # Reconstruction Loss
    if self.reconstruction_loss == 'zinb':
        reconst_loss = -log_zinb_positive(x, px_rate, torch.exp(px_r), px_dropout)
    elif self.reconstruction_loss == 'nb':
        reconst_loss = -log_nb_positive(x, px_rate, torch.exp(px_r))

    # KL Divergence
    mean = torch.zeros_like(qz_m)
    scale = torch.ones_like(qz_v)
    kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(dim=1)
    kl_divergence_l = kl(
        Normal(ql_m, torch.sqrt(ql_v)),
        Normal(local_l_mean, torch.sqrt(local_l_var)),
    ).sum(dim=1)
    kl_divergence = kl_divergence_z + kl_divergence_l

    return reconst_loss, kl_divergence
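# --- Illustrative aside (not part of the model code above) ---
# A minimal sketch of the (mu, theta)-parameterized negative binomial log-likelihood
# that a helper such as log_nb_positive above is assumed to compute; this is the
# standard textbook form, written here for illustration rather than taken from the library:
import torch

def log_nb_positive_sketch(x, mu, theta, eps=1e-8):
    """Log-likelihood of counts x under NB(mean=mu, inverse dispersion=theta), summed over genes."""
    log_theta_mu_eps = torch.log(theta + mu + eps)
    res = (
        theta * (torch.log(theta + eps) - log_theta_mu_eps)
        + x * (torch.log(mu + eps) - log_theta_mu_eps)
        + torch.lgamma(x + theta)
        - torch.lgamma(theta)
        - torch.lgamma(x + 1)
    )
    return res.sum(dim=-1)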
def loss(
    self,
    tensors,
    inference_outputs,
    generative_outputs,
    kl_weight: float = 1.0,
    n_obs: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # Parameters for z latent distribution
    qz_m = inference_outputs["qz_m"]
    qz_v = inference_outputs["qz_v"]
    ql_m = inference_outputs["ql_m"]
    ql_v = inference_outputs["ql_v"]
    px_rate = generative_outputs["px_rate"]
    px_r = generative_outputs["px_r"]
    px_dropout = generative_outputs["px_dropout"]
    bernoulli_params = generative_outputs["bernoulli_params"]
    x = tensors[_CONSTANTS.X_KEY]
    local_l_mean = tensors[_CONSTANTS.LOCAL_L_MEAN_KEY]
    local_l_var = tensors[_CONSTANTS.LOCAL_L_VAR_KEY]

    # KL divergences wrt z_n, l_n
    mean = torch.zeros_like(qz_m)
    scale = torch.ones_like(qz_v)
    kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(dim=1)
    kl_divergence_l = kl(
        Normal(ql_m, torch.sqrt(ql_v)),
        Normal(local_l_mean, torch.sqrt(local_l_var)),
    ).sum(dim=1)

    # KL divergence wrt Bernoulli parameters
    kl_divergence_bernoulli = self.compute_global_kl_divergence()

    # Reconstruction loss
    reconst_loss = self.get_reconstruction_loss(
        x, px_rate, px_r, px_dropout, bernoulli_params
    )

    kl_global = kl_divergence_bernoulli
    kl_local_for_warmup = kl_divergence_l
    kl_local_no_warmup = kl_divergence_z
    weighted_kl_local = kl_weight * kl_local_for_warmup + kl_local_no_warmup

    loss = n_obs * torch.mean(reconst_loss + weighted_kl_local) + kl_global

    kl_local = dict(kl_divergence_l=kl_divergence_l, kl_divergence_z=kl_divergence_z)
    return SCVILoss(loss, reconst_loss, kl_local, kl_global)
def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None):
    is_labelled = False if y is None else True

    # Prepare for sampling
    x_ = torch.log(1 + x)
    ql_m, ql_v, library = self.l_encoder(x_)

    # Enumerate choices of label
    ys, xs, library_s, batch_index_s = broadcast_labels(
        y, x, library, batch_index, n_broadcast=self.n_labels
    )

    # Sampling
    outputs = self.inference(xs, batch_index_s, ys)
    px_r = outputs["px_r"]
    px_rate = outputs["px_rate"]
    px_dropout = outputs["px_dropout"]
    qz_m = outputs["qz_m"]
    qz_v = outputs["qz_v"]

    reconst_loss = self.get_reconstruction_loss(xs, px_rate, px_r, px_dropout)

    # KL Divergence
    mean = torch.zeros_like(qz_m)
    scale = torch.ones_like(qz_v)
    kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(dim=1)
    kl_divergence_l = kl(
        Normal(ql_m, torch.sqrt(ql_v)),
        Normal(local_l_mean, torch.sqrt(local_l_var)),
    ).sum(dim=1)

    if is_labelled:
        return reconst_loss, kl_divergence_z + kl_divergence_l, 0.0

    reconst_loss = reconst_loss.view(self.n_labels, -1)
    probs = self.classifier(x_)
    reconst_loss = (reconst_loss.t() * probs).sum(dim=1)
    kl_divergence = (kl_divergence_z.view(self.n_labels, -1).t() * probs).sum(dim=1)
    kl_divergence += kl(
        Categorical(probs=probs),
        Categorical(probs=self.y_prior.repeat(probs.size(0), 1)),
    )
    kl_divergence += kl_divergence_l

    return reconst_loss, kl_divergence, 0.0
def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None):
    is_labelled = False if y is None else True

    # Prepare for sampling
    x_ = torch.log(1 + x)
    ql_m, ql_v, library = self.l_encoder(x_)

    # Enumerate choices of label
    ys, xs, library_s, batch_index_s = broadcast_labels(
        y, x, library, batch_index, n_broadcast=self.n_labels
    )

    xs_ = xs  # fall back to raw counts when log_variational is disabled
    if self.log_variational:
        xs_ = torch.log(1 + xs)

    # Sampling
    qz_m, qz_v, zs = self.z_encoder(xs_, ys)
    px_scale, px_r, px_rate, px_dropout = self.decoder(
        self.dispersion, zs, library_s, batch_index_s, ys
    )

    reconst_loss = self._reconstruction_loss(
        xs, px_rate, px_r, px_dropout, batch_index_s, ys
    )

    # KL Divergence
    mean = torch.zeros_like(qz_m)
    scale = torch.ones_like(qz_v)
    kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(dim=1)
    kl_divergence_l = kl(
        Normal(ql_m, torch.sqrt(ql_v)),
        Normal(local_l_mean, torch.sqrt(local_l_var)),
    ).sum(dim=1)

    if is_labelled:
        return reconst_loss, kl_divergence_z + kl_divergence_l

    reconst_loss = reconst_loss.view(self.n_labels, -1)
    probs = self.classifier(x_)
    reconst_loss = (reconst_loss.t() * probs).sum(dim=1)
    kl_divergence = (kl_divergence_z.view(self.n_labels, -1).t() * probs).sum(dim=1)
    kl_divergence += kl(
        Categorical(probs=probs),
        Categorical(probs=self.y_prior.repeat(probs.size(0), 1)),
    )
    kl_divergence += kl_divergence_l

    return reconst_loss, kl_divergence
def loss(
    self,
    tensors,
    inference_outputs,
    generative_outputs,
    kl_weight: float = 1.0,
):
    x = tensors[_CONSTANTS.X_KEY]
    local_l_mean = tensors[_CONSTANTS.LOCAL_L_MEAN_KEY]
    local_l_var = tensors[_CONSTANTS.LOCAL_L_VAR_KEY]

    qz_m = inference_outputs["qz_m"]
    qz_v = inference_outputs["qz_v"]
    ql_m = inference_outputs["ql_m"]
    ql_v = inference_outputs["ql_v"]
    px_rate = generative_outputs["px_rate"]
    px_r = generative_outputs["px_r"]
    px_dropout = generative_outputs["px_dropout"]

    mean = torch.zeros_like(qz_m)
    scale = torch.ones_like(qz_v)
    kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(dim=1)

    if not self.use_observed_lib_size:
        kl_divergence_l = kl(
            Normal(ql_m, torch.sqrt(ql_v)),
            Normal(local_l_mean, torch.sqrt(local_l_var)),
        ).sum(dim=1)
    else:
        kl_divergence_l = 0.0

    reconst_loss = self.get_reconstruction_loss(x, px_rate, px_r, px_dropout)

    kl_local_for_warmup = kl_divergence_l
    kl_local_no_warmup = kl_divergence_z
    weighted_kl_local = kl_weight * kl_local_for_warmup + kl_local_no_warmup

    loss = torch.mean(reconst_loss + weighted_kl_local)

    kl_local = dict(
        kl_divergence_l=kl_divergence_l, kl_divergence_z=kl_divergence_z
    )
    kl_global = 0.0
    return SCVILoss(loss, reconst_loss, kl_local, kl_global)
def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None, mode="scRNA", weighting=1): r""" Returns the reconstruction loss and the Kullback divergences :param x: tensor of values with shape ``(batch_size, n_input)`` or ``(batch_size, n_input_fish)`` depending on the mode :param local_l_mean: tensor of means of the prior distribution of latent variable l with shape (batch_size, 1) :param local_l_var: tensor of variances of the prior distribution of latent variable l with shape (batch_size, 1) :param batch_index: array that indicates which batch the cells belong to with shape ``batch_size`` :param y: tensor of cell-types labels with shape (batch_size, n_labels) :param mode: string that indicates the type of data we analyse :param weighting: used in none of these methods :return: the reconstruction loss and the Kullback divergences :rtype: 2-tuple of :py:class:`torch.FloatTensor` """ # Parameters for z latent distribution px_scale, px_r, px_rate, px_dropout, qz_m, qz_v, z, ql_m, ql_v, library = self.inference( x, batch_index, y, mode, weighting) # Reconstruction Loss reconst_loss = self.get_reconstruction_loss(x, px_rate, px_r, px_dropout, mode, weighting) # KL Divergence mean = torch.zeros_like(qz_m) scale = torch.ones_like(qz_v) kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(dim=1) if self.model_library: kl_divergence_l = kl(Normal(ql_m, torch.sqrt(ql_v)), Normal(local_l_mean, torch.sqrt(local_l_var))).sum(dim=1) kl_divergence = kl_divergence_z + kl_divergence_l else: kl_divergence = kl_divergence_z return reconst_loss, kl_divergence
def loss(
    self,
    tensors,
    inference_outputs,
    generative_outputs,
    kl_weight: float = 1.0,
):
    x = tensors[REGISTRY_KEYS.X_KEY]
    y = tensors[REGISTRY_KEYS.LABELS_KEY]

    qz_m = inference_outputs["qz_m"]
    qz_v = inference_outputs["qz_v"]
    px_rate = generative_outputs["px_rate"]
    px_r = generative_outputs["px_r"]

    mean = torch.zeros_like(qz_m)
    scale = torch.ones_like(qz_v)
    kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(dim=1)

    reconst_loss = -NegativeBinomial(px_rate, logits=px_r).log_prob(x).sum(-1)

    scaling_factor = self.ct_weight[y.long()[:, 0]]
    loss = torch.mean(scaling_factor * (reconst_loss + kl_weight * kl_divergence_z))

    return LossRecorder(loss, reconst_loss, kl_divergence_z, torch.tensor(0.0))
def loss(
    self,
    tensors,
    inference_outputs,
    generative_outputs,
):
    x = tensors[_CONSTANTS.X_KEY]

    qz_m = inference_outputs["qz_m"]
    qz_v = inference_outputs["qz_v"]
    px_logit = generative_outputs["px_logit"]

    mean = torch.zeros_like(qz_m)
    scale = torch.ones_like(qz_v)
    kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(dim=1)

    reconst_loss = -Bernoulli(logits=px_logit).log_prob(x).sum(dim=-1)

    loss = torch.mean(reconst_loss + kl_divergence_z)

    kl_local = dict(kl_divergence_z=kl_divergence_z)
    kl_global = 0.0
    return LossRecorder(loss, reconst_loss, kl_local, kl_global)
def compute_global_kl_divergence(self) -> torch.Tensor:
    outputs = self.get_alphas_betas(as_numpy=False)
    alpha_posterior = outputs["alpha_posterior"]
    beta_posterior = outputs["beta_posterior"]
    alpha_prior = outputs["alpha_prior"]
    beta_prior = outputs["beta_prior"]

    return kl(
        Beta(alpha_posterior, beta_posterior), Beta(alpha_prior, beta_prior)
    ).sum()
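# --- Illustrative aside (not part of the model code above) ---
# torch.distributions registers a closed-form KL for Beta distributions, so the global
# KL over the zero-inflation parameters above reduces to a single kl(...) call. Toy
# usage with hypothetical posterior and prior parameters:
import torch
from torch.distributions import Beta
from torch.distributions import kl_divergence as kl

alpha_posterior = torch.tensor([2.0, 5.0])
beta_posterior = torch.tensor([3.0, 1.0])
alpha_prior = torch.ones(2)
beta_prior = torch.ones(2)
kl_global = kl(Beta(alpha_posterior, beta_posterior), Beta(alpha_prior, beta_prior)).sum()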
def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None):
    is_labelled = False if y is None else True

    x_ = torch.log(1 + x)
    qz1_m, qz1_v, z1 = self.z_encoder(x_)
    ql_m, ql_v, library = self.l_encoder(x_)

    # Enumerate choices of label
    ys, z1s = broadcast_labels(y, z1, n_broadcast=self.n_labels)
    qz2_m, qz2_v, z2 = self.encoder_z2_z1(z1s, ys)
    pz1_m, pz1_v = self.decoder_z1_z2(z2, ys)

    px_scale, px_r, px_rate, px_dropout = self.decoder(
        self.dispersion, z1, library, batch_index
    )

    reconst_loss = self._reconstruction_loss(x, px_rate, px_r, px_dropout, batch_index, y)

    # KL Divergence
    mean = torch.zeros_like(qz2_m)
    scale = torch.ones_like(qz2_v)
    kl_divergence_z2 = kl(Normal(qz2_m, torch.sqrt(qz2_v)), Normal(mean, scale)).sum(dim=1)
    loss_z1_unweight = -Normal(pz1_m, torch.sqrt(pz1_v)).log_prob(z1s).sum(dim=-1)
    loss_z1_weight = Normal(qz1_m, torch.sqrt(qz1_v)).log_prob(z1).sum(dim=-1)
    kl_divergence_l = kl(
        Normal(ql_m, torch.sqrt(ql_v)),
        Normal(local_l_mean, torch.sqrt(local_l_var)),
    ).sum(dim=1)

    if is_labelled:
        return (
            reconst_loss + loss_z1_weight + loss_z1_unweight,
            kl_divergence_z2 + kl_divergence_l,
        )

    probs = self.classifier(z1)
    reconst_loss += loss_z1_weight + (
        (loss_z1_unweight).view(self.n_labels, -1).t() * probs
    ).sum(dim=1)
    kl_divergence = (kl_divergence_z2.view(self.n_labels, -1).t() * probs).sum(dim=1)
    kl_divergence += kl(
        Categorical(probs=probs),
        Categorical(probs=self.y_prior.repeat(probs.size(0), 1)),
    )
    kl_divergence += kl_divergence_l

    return reconst_loss, kl_divergence
def kl(self):
    w_mus = [weight_mu.view([-1]) for weight_mu in self.weight_mus]
    b_mus = [bias_mu.view([-1]) for bias_mu in self.bias_mus]
    mus = torch.cat(w_mus + b_mus)

    w_logsigs = [weight_logsig.view([-1]) for weight_logsig in self.weight_logsigs]
    b_logsigs = [bias_logsigs.view([-1]) for bias_logsigs in self.bias_logsigs]
    sigs = torch.cat(w_logsigs + b_logsigs).exp()

    q = Normal(mus, sigs)
    N = Normal(
        torch.zeros(len(mus), device=mus.device),
        torch.ones(len(mus), device=mus.device),
    )

    return kl(q, N)
def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None): r""" Returns the reconstruction loss and the Kullback divergences :param x: tensor of values with shape (batch_size, n_input) :param local_l_mean: tensor of means of the prior distribution of latent variable l with shape (batch_size, 1) :param local_l_var: tensor of variancess of the prior distribution of latent variable l with shape (batch_size, 1) :param batch_index: array that indicates which batch the cells belong to with shape ``batch_size`` :param y: tensor of cell-types labels with shape (batch_size, n_labels) :return: the reconstruction loss and the Kullback divergences :rtype: 2-tuple of :py:class:`torch.FloatTensor` """ # Parameters for z latent distribution outputs = self.inference(x, batch_index, None) qz_m = outputs["qz_m"] qz_v = outputs["qz_v"] ql_m = outputs["ql_m"] ql_v = outputs["ql_v"] px_rate = outputs["px_rate"] px_r = outputs["px_r"] px_dropout = outputs["px_dropout"] # KL Divergence mean = torch.zeros_like(qz_m) scale = torch.ones_like(qz_v) kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(dim=1) kl_divergence_l = kl( Normal(ql_m, torch.sqrt(ql_v)), Normal(local_l_mean, torch.sqrt(local_l_var)), ).sum(dim=1) kl_divergence = kl_divergence_z reconst_loss = self.get_reconstruction_loss(x, px_rate, px_r, px_dropout) if self.reconstruction_loss == "mse" or self.reconstruction_loss == "nb": kl_divergence_l = 1.0 print("reconst_loss=%f, kl_divergence=%f" % (torch.mean(reconst_loss), torch.mean(kl_divergence))) return reconst_loss + kl_divergence_l, kl_divergence, 0.0
def forward(self, X, local_l_mean=None, local_l_var=None):
    result = self.inference(X)
    latent_z_mu = result["latent_z_mu"]
    latent_z_logvar = result["latent_z_logvar"]
    latent_z = result["latent_z"]
    latent_l_mu = result["latent_l_mu"]
    latent_l_logvar = result["latent_l_logvar"]
    imputation = result["imputation"]
    disperation = result["disperation"]
    dropoutrate = result["dropoutrate"]

    # KL Divergence for library factor
    if local_l_mean is not None:
        kl_divergence_l = kl(
            Normal(latent_l_mu, latent_l_logvar),
            Normal(local_l_mean, torch.sqrt(local_l_var)),
        ).sum(dim=1)
    else:
        kl_divergence_l = torch.tensor(0.0)

    # KL Divergence for latent code
    if self.penality == "GMM":
        gamma, mu_c, var_c, pi = self.get_gamma(latent_z)  # , self.n_centroids, c_params)
        kl_divergence_z = GMM_loss(gamma, (mu_c, var_c, pi), (latent_z_mu, latent_z_logvar))
    else:
        mean = torch.zeros_like(latent_z_mu)
        scale = torch.ones_like(latent_z_logvar)
        kl_divergence_z = kl(Normal(latent_z_mu, latent_z_logvar), Normal(mean, scale)).sum(dim=1)

    reconst_loss = self.get_reconstruction_loss(X, imputation, disperation, dropoutrate)

    return reconst_loss, kl_divergence_l, kl_divergence_z
def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None): r""" Returns the reconstruction loss and the Kullback divergences :param x: tensor of values with shape (batch_size, n_input) :param local_l_mean: tensor of means of the prior distribution of latent variable l with shape (batch_size, 1) :param local_l_var: tensor of variancess of the prior distribution of latent variable l with shape (batch_size, 1) :param batch_index: array that indicates which batch the cells belong to with shape ``batch_size`` :param y: tensor of cell-types labels with shape (batch_size, n_labels) :return: the reconstruction loss and the Kullback divergences :rtype: 2-tuple of :py:class:`torch.FloatTensor` """ # Parameters for z latent distribution outputs = self.inference(x, batch_index, y) qz_m = outputs['qz_m'] qz_v = outputs['qz_v'] ql_m = outputs['ql_m'] ql_v = outputs['ql_v'] px_rate = outputs['px_rate'] px_r = outputs['px_r'] px_dropout = outputs['px_dropout'] # KL Divergence mean, scale = self.get_prior_params(device=qz_m.device) kl_divergence_z = kl(self.z_encoder.distrib(qz_m, qz_v), self.z_encoder.distrib(mean, scale)) if len(kl_divergence_z.size()) == 2: kl_divergence_z = kl_divergence_z.sum(dim=1) kl_divergence_l = kl(Normal(ql_m, torch.sqrt(ql_v)), Normal(local_l_mean, torch.sqrt(local_l_var))).sum(dim=1) kl_divergence = kl_divergence_z reconst_loss = self.get_reconstruction_loss(x, px_rate, px_r, px_dropout) return reconst_loss + kl_divergence_l, kl_divergence
def forward(self, x, y=None):
    is_labelled = False if y is None else True

    qz1_m, qz1_v, z1 = self.z_encoder(x)

    # Enumerate choices of label
    ys, z1s = broadcast_labels(y, z1, n_broadcast=self.n_labels)
    qz2_m, qz2_v, z2 = self.encoder_z2_z1(z1s, ys)
    pz1_m, pz1_v = self.decoder_z1_z2(z2, ys)
    qx_m, qx_v = self.decoder(z1)

    reconst_loss = self._reconstruction_loss(x, qx_m, qx_v)

    # KL Divergence
    mean = torch.zeros_like(qz2_m)
    scale = torch.ones_like(qz2_v)
    kl_divergence_z2 = kl(Normal(qz2_m, torch.sqrt(qz2_v)), Normal(mean, scale)).sum(dim=1)
    loss_z1_unweight = -Normal(pz1_m, torch.sqrt(pz1_v)).log_prob(z1s).sum(dim=-1)
    loss_z1_weight = Normal(qz1_m, torch.sqrt(qz1_v)).log_prob(z1).sum(dim=-1)

    if is_labelled:
        return reconst_loss + loss_z1_weight + loss_z1_unweight, kl_divergence_z2

    probs = self.classifier(z1)
    reconst_loss += loss_z1_weight + (
        (loss_z1_unweight).view(self.n_labels, -1).t() * probs
    ).sum(dim=1)
    kl_divergence = (kl_divergence_z2.view(self.n_labels, -1).t() * probs).sum(dim=1)
    kl_divergence += kl(
        Categorical(probs=probs),
        Categorical(probs=self.y_prior.repeat(probs.size(0), 1)),
    )

    return reconst_loss, kl_divergence
def forward(self, x): r""" Returns the reconstruction loss :param x: tensor of values with shape (batch_size, n_input) :return: the reconstruction loss and the Kullback divergences :rtype: 2-tuple of :py:class:`torch.FloatTensor` """ # Parameters for z latent distribution outputs = self.inference(x) qz_m = outputs["qz_m"] qz_v = outputs["qz_v"] px_rate = outputs["px_rate"] px_r = outputs["px_r"] z = outputs["z"] library = outputs["library"] self.encoder_variance.append( np.linalg.norm(qz_v.detach().cpu().numpy(), axis=1)) if self.use_MP: # Message passing likelihood self.initialize_visit() self.initialize_messages(z, self.barcodes, self.n_latent) self.perform_message_passing((self.tree & self.root), z.shape[1], False) mp_lik = self.aggregate_messages_into_leaves_likelihood( z.shape[1], add_prior=True) # Gaussian variational likelihood qz = Normal(qz_m, torch.sqrt(qz_v)).log_prob(z).sum(dim=-1) else: mp_lik = None # scVI Kl Divergence mean = torch.zeros_like(qz_m) scale = torch.ones_like(qz_v) qz = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(dim=1) # Reconstruction Loss if self.reconstruction_loss == "nb": reconst_loss = (-NegativeBinomial( mu=px_rate, theta=px_r).log_prob(x).sum(dim=-1)) elif self.reconstruction_loss == "poisson": reconst_loss = -Poisson(px_rate).log_prob(x).sum(dim=-1) return reconst_loss, qz, mp_lik
def loss(
    self,
    tensors,
    inference_outputs,
    generative_outputs,
):
    x = tensors[REGISTRY_KEYS.X_KEY]

    qz_m = inference_outputs["qz_m"]
    qz_v = inference_outputs["qz_v"]
    p = generative_outputs["px"]

    kld = kl(
        Normal(qz_m, torch.sqrt(qz_v)),
        Normal(0, 1),
    ).sum(dim=1)

    rl = self.get_reconstruction_loss(p, x)
    loss = (0.5 * rl + 0.5 * (kld * self.kl_weight)).mean()

    return LossRecorder(loss, rl, kld)