Example #1
    def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None):
        r""" Returns the reconstruction loss and the Kullback divergences

        :param x: tensor of values with shape (batch_size, n_input)
        :param local_l_mean: tensor of means of the prior distribution of latent variable l
         with shape (batch_size, 1)
        :param local_l_var: tensor of variances of the prior distribution of latent variable l
         with shape (batch_size, 1)
        :param batch_index: array that indicates which batch the cells belong to with shape ``batch_size``
        :param y: tensor of cell-types labels with shape (batch_size, n_labels)
        :return: the reconstruction loss and the Kullback divergences
        :rtype: 2-tuple of :py:class:`torch.FloatTensor`
        """
        # assert self.trained_decoder, "If you train the encoder alone please use the `ratio_loss`" \
        #                              "In `forward`, the KL terms are wrong"

        px_rate, qz_m, qz_v, z, ql_m, ql_v, library = self.inference(
            x, batch_index, y)

        # KL Divergence
        mean, scale = self.get_prior_params(device=qz_m.device)
        kl_divergence_z = kl(self.z_encoder.distrib(qz_m, qz_v),
                             self.z_encoder.distrib(mean, scale))
        if len(kl_divergence_z.size()) == 2:
            kl_divergence_z = kl_divergence_z.sum(dim=1)
        kl_divergence_l = kl(
            Normal(ql_m, torch.sqrt(ql_v)),
            Normal(local_l_mean, torch.sqrt(local_l_var)),
        ).sum(dim=1)
        kl_divergence = kl_divergence_z
        reconst_loss = self.get_reconstruction_loss(x, px_rate)
        return reconst_loss + kl_divergence_l, kl_divergence
Example #2
    def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None):
        # Parameters for z latent distribution
        x_ = x
        if self.log_variational:
            x_ = torch.log(1 + x_)

        # Sampling
        qz_m, qz_v, z = self.z_encoder(x_)
        ql_m, ql_v, library = self.l_encoder(x_)

        px_scale, px_r, px_rate, px_dropout = self.decoder(
            self.dispersion, z, library, batch_index)

        reconst_loss = self._reconstruction_loss(x, px_rate, px_r, px_dropout,
                                                 batch_index, y)

        # KL Divergence
        mean = torch.zeros_like(qz_m)
        scale = torch.ones_like(qz_v)

        kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)),
                             Normal(mean, scale)).sum(dim=1)
        kl_divergence_l = kl(Normal(ql_m, torch.sqrt(ql_v)),
                             Normal(local_l_mean,
                                    torch.sqrt(local_l_var))).sum(dim=1)
        kl_divergence = kl_divergence_z + kl_divergence_l

        return reconst_loss, kl_divergence
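
All of these forward/loss methods build the local KL terms the same way: `torch.distributions.kl_divergence` between diagonal Gaussians parameterized by a mean and a variance (hence the `torch.sqrt` on `qz_v`/`ql_v`), summed over the latent dimension. Below is a minimal standalone sketch with toy shapes (no scVI module assumed) that checks this call pattern against the closed-form Gaussian KL.

import torch
from torch.distributions import Normal
from torch.distributions import kl_divergence as kl

# Toy variational parameters in the convention used above:
# qz_m is the posterior mean, qz_v the posterior variance, so scale = sqrt(qz_v).
qz_m = torch.randn(8, 10)          # (batch_size, n_latent)
qz_v = torch.rand(8, 10) + 0.1     # strictly positive variances

# KL(q(z|x) || N(0, I)), summed over latent dimensions -> shape (batch_size,)
kl_z = kl(Normal(qz_m, torch.sqrt(qz_v)),
          Normal(torch.zeros_like(qz_m), torch.ones_like(qz_v))).sum(dim=1)

# Closed form for a diagonal Gaussian against a standard normal prior:
# 0.5 * sum_d (var_d + mu_d^2 - 1 - log var_d)
kl_closed = 0.5 * (qz_v + qz_m ** 2 - 1.0 - torch.log(qz_v)).sum(dim=1)

assert torch.allclose(kl_z, kl_closed, atol=1e-5)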
Example #3
    def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None):
        r""" Returns the reconstruction loss and the Kullback divergences

        :param x: tensor of values with shape (batch_size, n_input)
        :param local_l_mean: tensor of means of the prior distribution of latent variable l
         with shape (batch_size, 1)
        :param local_l_var: tensor of variances of the prior distribution of latent variable l
         with shape (batch_size, 1)
        :param batch_index: array that indicates which batch the cells belong to with shape ``batch_size``
        :param y: tensor of cell-types labels with shape (batch_size, n_labels)
        :return: the reconstruction loss and the Kullback divergences
        :rtype: 2-tuple of :py:class:`torch.FloatTensor`
        """
        # Parameters for z latent distribution

        px_scale, px_r, px_rate, px_dropout, qz_m, qz_v, z, ql_m, ql_v, library = self.inference(x, batch_index, y)
        reconst_loss = self._reconstruction_loss(x, px_rate, px_r, px_dropout)

        # KL Divergence
        mean = torch.zeros_like(qz_m)
        scale = torch.ones_like(qz_v)

        kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(dim=1)
        kl_divergence_l = kl(Normal(ql_m, torch.sqrt(ql_v)), Normal(local_l_mean, torch.sqrt(local_l_var))).sum(dim=1)
        kl_divergence = kl_divergence_z

        return reconst_loss + kl_divergence_l, kl_divergence
Example #4
    def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None):

        # Prepare for sampling
        x_ = torch.log(1 + x)
        ql_m, ql_v, library = self.l_encoder(x_)

        # Enumerate choices of label
        ys, xs, library_s, batch_index_s = (
            broadcast_labels(
                y, x, library, batch_index, n_broadcast=self.n_labels
            )
        )

        xs_ = xs  # keep raw counts when log_variational is False
        if self.log_variational:
            xs_ = torch.log(1 + xs)

        # Sampling
        qz_m, qz_v, zs = self.z_encoder(xs_, batch_index_s, ys)

        px_scale, px_r, px_rate, px_dropout = self.decoder(self.dispersion, zs, library_s, batch_index_s, ys)

        reconst_loss = self._reconstruction_loss(xs, px_rate, px_r, px_dropout, batch_index_s, ys)

        # KL Divergence
        mean = torch.zeros_like(qz_m)
        scale = torch.ones_like(qz_v)

        kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(dim=1)
        kl_divergence_l = kl(Normal(ql_m, torch.sqrt(ql_v)), Normal(local_l_mean, torch.sqrt(local_l_var))).sum(dim=1)

        return reconst_loss, kl_divergence_z + kl_divergence_l
Example #5
    def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None):
        is_labelled = False if y is None else True

        outputs = self.inference(x, batch_index, y)
        px_r = outputs["px_r"]
        px_rate = outputs["px_rate"]
        px_dropout = outputs["px_dropout"]
        qz1_m = outputs["qz_m"]
        qz1_v = outputs["qz_v"]
        z1 = outputs["z"]
        ql_m = outputs["ql_m"]
        ql_v = outputs["ql_v"]

        # Enumerate choices of label
        ys, z1s = broadcast_labels(y, z1, n_broadcast=self.n_labels)
        qz2_m, qz2_v, z2 = self.encoder_z2_z1(z1s, ys)
        pz1_m, pz1_v = self.decoder_z1_z2(z2, ys)

        reconst_loss = self.get_reconstruction_loss(x, px_rate, px_r,
                                                    px_dropout)

        # KL Divergence
        mean = torch.zeros_like(qz2_m)
        scale = torch.ones_like(qz2_v)

        kl_divergence_z2 = kl(Normal(qz2_m, torch.sqrt(qz2_v)),
                              Normal(mean, scale)).sum(dim=1)
        loss_z1_unweight = -Normal(pz1_m,
                                   torch.sqrt(pz1_v)).log_prob(z1s).sum(dim=-1)
        loss_z1_weight = Normal(qz1_m,
                                torch.sqrt(qz1_v)).log_prob(z1).sum(dim=-1)
        if not self.use_observed_lib_size:
            kl_divergence_l = kl(
                Normal(ql_m, torch.sqrt(ql_v)),
                Normal(local_l_mean, torch.sqrt(local_l_var)),
            ).sum(dim=1)
        else:
            kl_divergence_l = 0.0

        if is_labelled:
            return (
                reconst_loss + loss_z1_weight + loss_z1_unweight,
                kl_divergence_z2 + kl_divergence_l,
                0.0,
            )

        probs = self.classifier(z1)
        reconst_loss += loss_z1_weight + (
            (loss_z1_unweight).view(self.n_labels, -1).t() * probs).sum(dim=1)

        kl_divergence = (kl_divergence_z2.view(self.n_labels, -1).t() *
                         probs).sum(dim=1)
        kl_divergence += kl(
            Categorical(probs=probs),
            Categorical(probs=self.y_prior.repeat(probs.size(0), 1)),
        )
        kl_divergence += kl_divergence_l

        return reconst_loss, kl_divergence, 0.0
Example #6
    def forward(
        self,
        x: torch.Tensor,
        local_l_mean: torch.Tensor,
        local_l_var: torch.Tensor,
        batch_index: Optional[torch.Tensor] = None,
        y: Optional[torch.Tensor] = None,
        mode: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return the reconstruction loss and the Kullback divergences

        :param x: tensor of values with shape ``(batch_size, n_input)``
         or ``(batch_size, n_input_fish)`` depending on the mode
        :param local_l_mean: tensor of means of the prior distribution of latent variable l
         with shape (batch_size, 1)
        :param local_l_var: tensor of variances of the prior distribution of latent variable l
         with shape (batch_size, 1)
        :param batch_index: array that indicates which batch the cells belong to with shape ``batch_size``
        :param y: tensor of cell-types labels with shape (batch_size, n_labels)
        :param mode: indicates which head/tail to use in the joint network
        :return: the reconstruction loss and the Kullback divergences
        """
        if mode is None:
            if len(self.n_input_list) == 1:
                mode = 0
            else:
                raise Exception("Must provide a mode")

        qz_m, qz_v, z, ql_m, ql_v, library = self.encode(x, mode)
        px_scale, px_r, px_rate, px_dropout = self.decode(
            z, mode, library, batch_index, y
        )

        # mask loss to observed genes
        mapping_indices = self.indices_mappings[mode]
        reconstruction_loss = self.reconstruction_loss(
            x,
            px_rate[:, mapping_indices],
            px_r[:, mapping_indices],
            px_dropout[:, mapping_indices],
            mode,
        )

        # KL Divergence
        mean = torch.zeros_like(qz_m)
        scale = torch.ones_like(qz_v)
        kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(
            dim=1
        )

        if self.model_library_bools[mode]:
            kl_divergence_l = kl(
                Normal(ql_m, torch.sqrt(ql_v)),
                Normal(local_l_mean, torch.sqrt(local_l_var)),
            ).sum(dim=1)
        else:
            kl_divergence_l = torch.zeros_like(kl_divergence_z)

        return reconstruction_loss, kl_divergence_l + kl_divergence_z, 0.0
Example #7
    def forward(self,
                x,
                local_l_mean,
                local_l_var,
                batch_index=None,
                y=None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Returns the reconstruction loss and the KL divergences.

        Parameters
        ----------
        x
            tensor of values with shape (batch_size, n_input)
        local_l_mean
            tensor of means of the prior distribution of latent variable l
            with shape (batch_size, 1)
        local_l_var
            tensor of variances of the prior distribution of latent variable l
            with shape (batch_size, 1)
        batch_index
            array that indicates which batch the cells belong to with shape ``batch_size`` (Default value = None)
        y
            tensor of cell-types labels with shape (batch_size, n_labels) (Default value = None)

        Returns
        -------
        type
            the reconstruction loss and the Kullback divergences
        """
        # Parameters for z latent distribution
        outputs = self.inference(x, batch_index, y)
        qz_m = outputs["qz_m"]
        qz_v = outputs["qz_v"]
        ql_m = outputs["ql_m"]
        ql_v = outputs["ql_v"]
        px_rate = outputs["px_rate"]
        px_r = outputs["px_r"]
        px_dropout = outputs["px_dropout"]

        # KL Divergence
        mean = torch.zeros_like(qz_m)
        scale = torch.ones_like(qz_v)

        kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)),
                             Normal(mean, scale)).sum(dim=1)
        if not self.use_observed_lib_size:
            kl_divergence_l = kl(
                Normal(ql_m, torch.sqrt(ql_v)),
                Normal(local_l_mean, torch.sqrt(local_l_var)),
            ).sum(dim=1)
        else:
            kl_divergence_l = 0.0
        kl_divergence = kl_divergence_z

        reconst_loss = self.get_reconstruction_loss(x, px_rate, px_r,
                                                    px_dropout)

        return reconst_loss + kl_divergence_l, kl_divergence, 0.0
Example #8
    def forward(self, X1, X2, local_l_mean, local_l_var, local_l_mean1,
                local_l_var1):

        result = self.inference(X1, X2)

        disper_x = result["disper_x"]
        recon_x1 = result["recon_x1"]
        dropout_rate = result["dropout_rate"]

        disper_x2 = result["disper_x2"]
        recon_x_2 = result["recon_x_2"]
        dropout_rate_2 = result["dropout_rate_2"]

        if X1 is not None:
            mean_l = result["mean_l"]
            logvar_l = result["logvar_l"]

            kl_divergence_l = kl(Normal(mean_l, logvar_l),
                                 Normal(local_l_mean,
                                        torch.sqrt(local_l_var))).sum(dim=1)
        else:
            kl_divergence_l = torch.tensor(0.0)

        if X2 is not None:
            if self.Type == 'ZINB':
                mean_l2 = result["mean_l2"]
                logvar_l2 = result["library2"]
                kl_divergence_l2 = kl(
                    Normal(mean_l2, logvar_l2),
                    Normal(local_l_mean1, torch.sqrt(local_l_var1))).sum(dim=1)
            else:
                kl_divergence_l2 = torch.tensor(0.0)
        else:
            kl_divergence_l2 = torch.tensor(0.0)

        mean_z = result["mean_z"]
        logvar_z = result["logvar_z"]
        latent_z = result["latent_z"]

        if self.penality == "GMM":
            gamma, mu_c, var_c, pi = self.get_gamma(
                latent_z)  #, self.n_centroids, c_params)
            kl_divergence_z = GMM_loss(gamma, (mu_c, var_c, pi),
                                       (mean_z, logvar_z))

        else:
            mean = torch.zeros_like(mean_z)
            scale = torch.ones_like(logvar_z)
            kl_divergence_z = kl(Normal(mean_z, logvar_z),
                                 Normal(mean, scale)).sum(dim=1)

        loss1, loss2 = get_both_recon_loss(X1, recon_x1, disper_x,
                                           dropout_rate, X2, recon_x_2,
                                           disper_x2, dropout_rate_2, "ZINB",
                                           self.Type)

        return loss1, loss2, kl_divergence_l, kl_divergence_l2, kl_divergence_z
Example #9
    def loss(
        self,
        tensors,
        inference_outputs,
        generative_outputs,
        kl_weight: float = 1.0,
        n_obs: float = 1.0,
    ) -> LossRecorder:
        # Parameters for z latent distribution
        qz_m = inference_outputs["qz_m"]
        qz_v = inference_outputs["qz_v"]
        px_rate = generative_outputs["px_rate"]
        px_r = generative_outputs["px_r"]
        px_dropout = generative_outputs["px_dropout"]
        bernoulli_params = generative_outputs["bernoulli_params"]
        x = tensors[_CONSTANTS.X_KEY]
        batch_index = tensors[_CONSTANTS.BATCH_KEY]

        # KL divergences wrt z_n,l_n
        mean = torch.zeros_like(qz_m)
        scale = torch.ones_like(qz_v)

        kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)),
                             Normal(mean, scale)).sum(dim=1)
        if not self.use_observed_lib_size:
            ql_m = inference_outputs["ql_m"]
            ql_v = inference_outputs["ql_v"]
            (
                local_library_log_means,
                local_library_log_vars,
            ) = self._compute_local_library_params(batch_index)

            kl_divergence_l = kl(
                Normal(ql_m, torch.sqrt(ql_v)),
                Normal(local_library_log_means,
                       torch.sqrt(local_library_log_vars)),
            ).sum(dim=1)
        else:
            kl_divergence_l = 0.0

        # KL divergence wrt Bernoulli parameters
        kl_divergence_bernoulli = self.compute_global_kl_divergence()

        # Reconstruction loss
        reconst_loss = self.get_reconstruction_loss(x, px_rate, px_r,
                                                    px_dropout,
                                                    bernoulli_params)

        kl_global = kl_divergence_bernoulli
        kl_local_for_warmup = kl_divergence_z
        kl_local_no_warmup = kl_divergence_l

        weighted_kl_local = kl_weight * kl_local_for_warmup + kl_local_no_warmup
        loss = n_obs * torch.mean(reconst_loss + weighted_kl_local) + kl_global
        kl_local = dict(kl_divergence_l=kl_divergence_l,
                        kl_divergence_z=kl_divergence_z)
        return LossRecorder(loss, reconst_loss, kl_local, kl_global)
Example #10
    def forward(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        local_l_mean_gene: torch.Tensor,
        local_l_var_gene: torch.Tensor,
        batch_index: Optional[torch.Tensor] = None,
        label: Optional[torch.Tensor] = None,
    ):
        r""" Returns the reconstruction loss and the Kullback divergences

        :param x: tensor of values with shape (batch_size, n_input_genes)
        :param y: tensor of values with shape (batch_size, n_input_proteins)
        :param local_l_mean_gene: tensor of means of the prior distribution of latent variable l
         with shape (batch_size, 1)
        :param local_l_var_gene: tensor of variances of the prior distribution of latent variable l
         with shape (batch_size, 1)
        :param batch_index: array that indicates which batch the cells belong to with shape ``batch_size``
        :param label: tensor of cell-types labels with shape (batch_size, n_labels)
        :return: the reconstruction loss and the Kullback divergences
        :rtype: 5-tuple of :py:class:`torch.FloatTensor`
        """
        # Parameters for z latent distribution

        outputs = self.inference(x, y, batch_index, label)
        qz_m = outputs["qz_m"]
        qz_v = outputs["qz_v"]
        ql_m = outputs["ql_m"]
        ql_v = outputs["ql_v"]
        px_ = outputs["px_"]
        py_ = outputs["py_"]

        reconst_loss_gene, reconst_loss_protein = self.get_reconstruction_loss(
            x, y, px_, py_
        )

        # KL Divergence
        kl_div_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(0, 1)).sum(dim=1)
        kl_div_l_gene = kl(
            Normal(ql_m, torch.sqrt(ql_v)),
            Normal(local_l_mean_gene, torch.sqrt(local_l_var_gene)),
        ).sum(dim=1)

        kl_div_back_pro = kl(
            Normal(py_["back_alpha"], py_["back_beta"]), self.back_mean_prior
        ).sum(dim=-1)

        return (
            reconst_loss_gene,
            reconst_loss_protein,
            kl_div_z,
            kl_div_l_gene,
            kl_div_back_pro,
        )
Example #11
    def forward(
        self,
        x: torch.Tensor,
        local_l_mean: torch.Tensor,
        local_l_var: torch.Tensor,
        batch_index: Optional[torch.Tensor] = None,
        y: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r""" Returns the reconstruction loss and the Kullback divergences

        :param x: tensor of values with shape (batch_size, n_input)
        :param local_l_mean: tensor of means of the prior distribution of latent variable l
         with shape (batch_size, 1)
        :param local_l_var: tensor of variances of the prior distribution of latent variable l
         with shape (batch_size, 1)
        :param batch_index: array that indicates which batch the cells belong to with shape ``batch_size``
        :param y: tensor of cell-types labels with shape (batch_size, n_labels)
        :return: the reconstruction loss and the Kullback divergences
        :rtype: 3-tuple of :py:class:`torch.FloatTensor`
        """
        # Parameters for z latent distribution
        outputs = self.inference(x, batch_index, y)
        qz_m = outputs["qz_m"]
        qz_v = outputs["qz_v"]
        ql_m = outputs["ql_m"]
        ql_v = outputs["ql_v"]
        px_rate = outputs["px_rate"]
        px_r = outputs["px_r"]
        px_dropout = outputs["px_dropout"]
        bernoulli_params = outputs["bernoulli_params"]

        # KL divergences wrt z_n,l_n
        mean = torch.zeros_like(qz_m)
        scale = torch.ones_like(qz_v)

        kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(
            dim=1
        )
        kl_divergence_l = kl(
            Normal(ql_m, torch.sqrt(ql_v)),
            Normal(local_l_mean, torch.sqrt(local_l_var)),
        ).sum(dim=1)

        # KL divergence wrt Bernoulli parameters
        kl_divergence_bernoulli = self.compute_global_kl_divergence()

        # Reconstruction loss
        reconst_loss = self.get_reconstruction_loss(
            x, px_rate, px_r, px_dropout, bernoulli_params
        )

        return reconst_loss + kl_divergence_l, kl_divergence_z, kl_divergence_bernoulli
Example #12
    def loss(
        self,
        tensors,
        inference_outputs,
        generative_outputs,
        kl_weight: float = 1.0,
    ):
        x = tensors[REGISTRY_KEYS.X_KEY]
        batch_index = tensors[REGISTRY_KEYS.BATCH_KEY]

        qz_m = inference_outputs["qz_m"]
        qz_v = inference_outputs["qz_v"]
        px_rate = generative_outputs["px_rate"]
        px_r = generative_outputs["px_r"]
        px_dropout = generative_outputs["px_dropout"]

        mean = torch.zeros_like(qz_m)
        scale = torch.ones_like(qz_v)

        kl_divergence_z = kl(Normal(qz_m, qz_v.sqrt()), Normal(mean, scale)).sum(dim=1)

        if not self.use_observed_lib_size:
            ql_m = inference_outputs["ql_m"]
            ql_v = inference_outputs["ql_v"]
            (
                local_library_log_means,
                local_library_log_vars,
            ) = self._compute_local_library_params(batch_index)

            kl_divergence_l = kl(
                Normal(ql_m, ql_v.sqrt()),
                Normal(local_library_log_means, local_library_log_vars.sqrt()),
            ).sum(dim=1)
        else:
            kl_divergence_l = 0.0

        reconst_loss = self.get_reconstruction_loss(x, px_rate, px_r, px_dropout)

        kl_local_for_warmup = kl_divergence_z
        kl_local_no_warmup = kl_divergence_l

        weighted_kl_local = kl_weight * kl_local_for_warmup + kl_local_no_warmup

        loss = torch.mean(reconst_loss + weighted_kl_local)

        kl_local = dict(
            kl_divergence_l=kl_divergence_l, kl_divergence_z=kl_divergence_z
        )
        kl_global = torch.tensor(0.0)
        return LossRecorder(loss, reconst_loss, kl_local, kl_global)
Example #13
    def loss(
        self,
        tensors,
        inference_outputs,
        generative_outputs,
        kl_weight: float = 1.0,
    ):
        kl_weight = self.kl_factor * kl_weight
        x = tensors[_CONSTANTS.X_KEY]
        local_l_mean = tensors[_CONSTANTS.LOCAL_L_MEAN_KEY]
        local_l_var = tensors[_CONSTANTS.LOCAL_L_VAR_KEY]

        qz_m = inference_outputs["qz_m"]
        qz_v = inference_outputs["qz_v"]
        ql_m = inference_outputs["ql_m"]
        ql_v = inference_outputs["ql_v"]
        px_rate = generative_outputs["px_rate"]
        px_r = generative_outputs["px_r"]
        px_dropout = generative_outputs["px_dropout"]

        mean = torch.zeros_like(qz_m)
        scale = torch.ones_like(qz_v)

        kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(
            dim=1
        )

        kl_divergence_l = kl(
            Normal(ql_m, torch.sqrt(ql_v)),
            Normal(local_l_mean, torch.sqrt(local_l_var)),
        ).sum(dim=1)

        reconst_loss = (
            -ZeroInflatedNegativeBinomial(mu=px_rate, theta=px_r, zi_logits=px_dropout)
            .log_prob(x)
            .sum(dim=-1)
        )

        kl_local_for_warmup = kl_divergence_z
        kl_local_no_warmup = kl_divergence_l

        weighted_kl_local = kl_weight * kl_local_for_warmup + kl_local_no_warmup

        loss = torch.mean(reconst_loss + weighted_kl_local)

        kl_local = dict(
            kl_divergence_l=kl_divergence_l, kl_divergence_z=kl_divergence_z
        )
        kl_global = 0.0
        return LossRecorder(loss, reconst_loss, kl_local, kl_global)
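
Reading off the two `loss` methods above (with the reconstruction term written as a negative log-likelihood, explicit in Example #13), and writing $w$ for `kl_weight`, the scalar objective over a minibatch of size $B$ is

$$\mathcal{L} \;=\; \frac{1}{B}\sum_{n=1}^{B}\Big[-\log p_\theta(x_n \mid z_n, \ell_n) \;+\; w\,\mathrm{KL}\big(q(z_n \mid x_n)\,\|\,p(z_n)\big) \;+\; \mathrm{KL}\big(q(\ell_n \mid x_n)\,\|\,p(\ell_n)\big)\Big],$$

i.e. only the $z$ term is annealed during KL warmup, while the library-size term enters at full weight (and drops out entirely when the observed library size is used).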
Example #14
    def forward(self,
                x,
                local_l_mean,
                local_l_var,
                batch_index=None,
                y=None):  # same signature as loss
        # Parameters for z latent distribution
        x_ = x
        if self.log_variational:
            x_ = torch.log(1 + x_)

        # Sampling
        qz_m, qz_v, z = self.z_encoder(x_)
        ql_m, ql_v, library = self.l_encoder(x_)

        if self.dispersion == "gene-cell":
            px_scale, self.px_r, px_rate, px_dropout = self.decoder(
                self.dispersion, z, library, batch_index)
        else:  # self.dispersion == "gene", "gene-batch",  "gene-label"
            px_scale, px_rate, px_dropout = self.decoder(
                self.dispersion, z, library, batch_index)

        if self.dispersion == "gene-label":
            px_r = F.linear(
                one_hot(y, self.n_labels),
                self.px_r)  # px_r gets transposed - last dimension is nb genes
        elif self.dispersion == "gene-batch":
            px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
        else:
            px_r = self.px_r

        # Reconstruction Loss
        if self.reconstruction_loss == 'zinb':
            reconst_loss = -log_zinb_positive(x, px_rate, torch.exp(px_r),
                                              px_dropout)
        elif self.reconstruction_loss == 'nb':
            reconst_loss = -log_nb_positive(x, px_rate, torch.exp(px_r))

        # KL Divergence
        mean = torch.zeros_like(qz_m)
        scale = torch.ones_like(qz_v)

        kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)),
                             Normal(mean, scale)).sum(dim=1)
        kl_divergence_l = kl(Normal(ql_m, torch.sqrt(ql_v)),
                             Normal(local_l_mean,
                                    torch.sqrt(local_l_var))).sum(dim=1)
        kl_divergence = kl_divergence_z + kl_divergence_l

        return reconst_loss, kl_divergence
Example #15
    def loss(
        self,
        tensors,
        inference_outputs,
        generative_outputs,
        kl_weight: float = 1.0,
        n_obs: float = 1.0,
    ) -> SCVILoss:
        # Parameters for z latent distribution
        qz_m = inference_outputs["qz_m"]
        qz_v = inference_outputs["qz_v"]
        ql_m = inference_outputs["ql_m"]
        ql_v = inference_outputs["ql_v"]
        px_rate = generative_outputs["px_rate"]
        px_r = generative_outputs["px_r"]
        px_dropout = generative_outputs["px_dropout"]
        bernoulli_params = generative_outputs["bernoulli_params"]
        x = tensors[_CONSTANTS.X_KEY]
        local_l_mean = tensors[_CONSTANTS.LOCAL_L_MEAN_KEY]
        local_l_var = tensors[_CONSTANTS.LOCAL_L_VAR_KEY]

        # KL divergences wrt z_n,l_n
        mean = torch.zeros_like(qz_m)
        scale = torch.ones_like(qz_v)

        kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)),
                             Normal(mean, scale)).sum(dim=1)
        kl_divergence_l = kl(
            Normal(ql_m, torch.sqrt(ql_v)),
            Normal(local_l_mean, torch.sqrt(local_l_var)),
        ).sum(dim=1)

        # KL divergence wrt Bernoulli parameters
        kl_divergence_bernoulli = self.compute_global_kl_divergence()

        # Reconstruction loss
        reconst_loss = self.get_reconstruction_loss(x, px_rate, px_r,
                                                    px_dropout,
                                                    bernoulli_params)

        kl_global = kl_divergence_bernoulli
        kl_local_for_warmup = kl_divergence_l
        kl_local_no_warmup = kl_divergence_z

        weighted_kl_local = kl_weight * kl_local_for_warmup + kl_local_no_warmup
        loss = n_obs * torch.mean(reconst_loss + weighted_kl_local) + kl_global
        kl_local = dict(kl_divergence_l=kl_divergence_l,
                        kl_divergence_z=kl_divergence_z)
        return SCVILoss(loss, reconst_loss, kl_local, kl_global)
Example #16
    def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None):
        is_labelled = False if y is None else True

        # Prepare for sampling
        x_ = torch.log(1 + x)
        ql_m, ql_v, library = self.l_encoder(x_)

        # Enumerate choices of label
        ys, xs, library_s, batch_index_s = broadcast_labels(
            y, x, library, batch_index, n_broadcast=self.n_labels
        )

        # Sampling
        outputs = self.inference(xs, batch_index_s, ys)
        px_r = outputs["px_r"]
        px_rate = outputs["px_rate"]
        px_dropout = outputs["px_dropout"]
        qz_m = outputs["qz_m"]
        qz_v = outputs["qz_v"]
        reconst_loss = self.get_reconstruction_loss(xs, px_rate, px_r, px_dropout)

        # KL Divergence
        mean = torch.zeros_like(qz_m)
        scale = torch.ones_like(qz_v)

        kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(
            dim=1
        )
        kl_divergence_l = kl(
            Normal(ql_m, torch.sqrt(ql_v)),
            Normal(local_l_mean, torch.sqrt(local_l_var)),
        ).sum(dim=1)

        if is_labelled:
            return reconst_loss, kl_divergence_z + kl_divergence_l, 0.0

        reconst_loss = reconst_loss.view(self.n_labels, -1)

        probs = self.classifier(x_)
        reconst_loss = (reconst_loss.t() * probs).sum(dim=1)

        kl_divergence = (kl_divergence_z.view(self.n_labels, -1).t() * probs).sum(dim=1)
        kl_divergence += kl(
            Categorical(probs=probs),
            Categorical(probs=self.y_prior.repeat(probs.size(0), 1)),
        )
        kl_divergence += kl_divergence_l

        return reconst_loss, kl_divergence, 0.0
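
As the unlabelled branch above shows, when no label is observed the per-label quantities (computed over the `broadcast_labels` expansion) are averaged under the classifier posterior $q(y \mid x)$ and a categorical KL against the label prior is added, giving

$$-\,\mathbb{E}_{q(y \mid x)}\big[\log p_\theta(x \mid z, \ell, y)\big] \;+\; \mathbb{E}_{q(y \mid x)}\big[\mathrm{KL}_z\big] \;+\; \mathrm{KL}\big(q(y \mid x)\,\|\,p(y)\big) \;+\; \mathrm{KL}_\ell,$$

which is exactly what the `probs`-weighted sums and the `Categorical` KL implement in the code above.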
Example #17
    def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None):
        is_labelled = False if y is None else True

        # Prepare for sampling
        x_ = torch.log(1 + x)
        ql_m, ql_v, library = self.l_encoder(x_)

        # Enumerate choices of label
        ys, xs, library_s, batch_index_s = (broadcast_labels(
            y, x, library, batch_index, n_broadcast=self.n_labels))

        xs_ = xs  # keep raw counts when log_variational is False
        if self.log_variational:
            xs_ = torch.log(1 + xs)

        # Sampling
        qz_m, qz_v, zs = self.z_encoder(xs_, ys)

        px_scale, px_r, px_rate, px_dropout = self.decoder(
            self.dispersion, zs, library_s, batch_index_s, ys)

        reconst_loss = self._reconstruction_loss(xs, px_rate, px_r, px_dropout,
                                                 batch_index_s, ys)

        # KL Divergence
        mean = torch.zeros_like(qz_m)
        scale = torch.ones_like(qz_v)

        kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)),
                             Normal(mean, scale)).sum(dim=1)
        kl_divergence_l = kl(Normal(ql_m, torch.sqrt(ql_v)),
                             Normal(local_l_mean,
                                    torch.sqrt(local_l_var))).sum(dim=1)

        if is_labelled:
            return reconst_loss, kl_divergence_z + kl_divergence_l

        reconst_loss = reconst_loss.view(self.n_labels, -1)

        probs = self.classifier(x_)
        reconst_loss = (reconst_loss.t() * probs).sum(dim=1)

        kl_divergence = (kl_divergence_z.view(self.n_labels, -1).t() *
                         probs).sum(dim=1)
        kl_divergence += kl(
            Categorical(probs=probs),
            Categorical(probs=self.y_prior.repeat(probs.size(0), 1)))
        kl_divergence += kl_divergence_l

        return reconst_loss, kl_divergence
Example #18
    def loss(
        self,
        tensors,
        inference_outputs,
        generative_outputs,
        kl_weight: float = 1.0,
    ):
        x = tensors[_CONSTANTS.X_KEY]
        local_l_mean = tensors[_CONSTANTS.LOCAL_L_MEAN_KEY]
        local_l_var = tensors[_CONSTANTS.LOCAL_L_VAR_KEY]

        qz_m = inference_outputs["qz_m"]
        qz_v = inference_outputs["qz_v"]
        ql_m = inference_outputs["ql_m"]
        ql_v = inference_outputs["ql_v"]
        px_rate = generative_outputs["px_rate"]
        px_r = generative_outputs["px_r"]
        px_dropout = generative_outputs["px_dropout"]

        mean = torch.zeros_like(qz_m)
        scale = torch.ones_like(qz_v)

        kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(
            dim=1
        )

        if not self.use_observed_lib_size:
            kl_divergence_l = kl(
                Normal(ql_m, torch.sqrt(ql_v)),
                Normal(local_l_mean, torch.sqrt(local_l_var)),
            ).sum(dim=1)
        else:
            kl_divergence_l = 0.0

        reconst_loss = self.get_reconstruction_loss(x, px_rate, px_r, px_dropout)

        kl_local_for_warmup = kl_divergence_l
        kl_local_no_warmup = kl_divergence_z

        weighted_kl_local = kl_weight * kl_local_for_warmup + kl_local_no_warmup

        loss = torch.mean(reconst_loss + weighted_kl_local)

        kl_local = dict(
            kl_divergence_l=kl_divergence_l, kl_divergence_z=kl_divergence_z
        )
        kl_global = 0.0
        return SCVILoss(loss, reconst_loss, kl_local, kl_global)
Example #19
    def forward(self,
                x,
                local_l_mean,
                local_l_var,
                batch_index=None,
                y=None,
                mode="scRNA",
                weighting=1):
        r""" Returns the reconstruction loss and the Kullback divergences

        :param x: tensor of values with shape ``(batch_size, n_input)``
            or ``(batch_size, n_input_fish)`` depending on the mode
        :param local_l_mean: tensor of means of the prior distribution of latent variable l
            with shape (batch_size, 1)
        :param local_l_var: tensor of variances of the prior distribution of latent variable l
            with shape (batch_size, 1)
        :param batch_index: array that indicates which batch the cells belong to with shape ``batch_size``
        :param y: tensor of cell-types labels with shape (batch_size, n_labels)
        :param mode: string that indicates the type of data we analyse
        :param weighting: used in none of these methods
        :return: the reconstruction loss and the Kullback divergences
        :rtype: 2-tuple of :py:class:`torch.FloatTensor`
        """

        # Parameters for z latent distribution
        px_scale, px_r, px_rate, px_dropout, qz_m, qz_v, z, ql_m, ql_v, library = self.inference(
            x, batch_index, y, mode, weighting)

        # Reconstruction Loss
        reconst_loss = self.get_reconstruction_loss(x, px_rate, px_r,
                                                    px_dropout, mode,
                                                    weighting)

        # KL Divergence
        mean = torch.zeros_like(qz_m)
        scale = torch.ones_like(qz_v)

        kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)),
                             Normal(mean, scale)).sum(dim=1)
        if self.model_library:
            kl_divergence_l = kl(Normal(ql_m, torch.sqrt(ql_v)),
                                 Normal(local_l_mean,
                                        torch.sqrt(local_l_var))).sum(dim=1)
            kl_divergence = kl_divergence_z + kl_divergence_l
        else:
            kl_divergence = kl_divergence_z

        return reconst_loss, kl_divergence
Example #20
    def loss(
        self,
        tensors,
        inference_outputs,
        generative_outputs,
        kl_weight: float = 1.0,
    ):
        x = tensors[REGISTRY_KEYS.X_KEY]
        y = tensors[REGISTRY_KEYS.LABELS_KEY]
        qz_m = inference_outputs["qz_m"]
        qz_v = inference_outputs["qz_v"]
        px_rate = generative_outputs["px_rate"]
        px_r = generative_outputs["px_r"]

        mean = torch.zeros_like(qz_m)
        scale = torch.ones_like(qz_v)

        kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)),
                             Normal(mean, scale)).sum(dim=1)

        reconst_loss = -NegativeBinomial(px_rate,
                                         logits=px_r).log_prob(x).sum(-1)
        scaling_factor = self.ct_weight[y.long()[:, 0]]
        loss = torch.mean(scaling_factor *
                          (reconst_loss + kl_weight * kl_divergence_z))

        return LossRecorder(loss, reconst_loss, kl_divergence_z,
                            torch.tensor(0.0))
Example #21
    def loss(
        self,
        tensors,
        inference_outputs,
        generative_outputs,
    ):
        x = tensors[_CONSTANTS.X_KEY]

        qz_m = inference_outputs["qz_m"]
        qz_v = inference_outputs["qz_v"]
        px_logit = generative_outputs["px_logit"]

        mean = torch.zeros_like(qz_m)
        scale = torch.ones_like(qz_v)

        kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(
            dim=1
        )

        reconst_loss = (
            -Bernoulli(logits=px_logit)
            .log_prob(x)
            .sum(dim=-1)
        )

        loss = torch.mean(reconst_loss + kl_divergence_z)

        kl_local = dict(
            kl_divergence_z=kl_divergence_z
        )
        kl_global = 0.0
        return LossRecorder(loss, reconst_loss, kl_local, kl_global)
Example #22
 def compute_global_kl_divergence(self) -> torch.Tensor:
     outputs = self.get_alphas_betas(as_numpy=False)
     alpha_posterior = outputs["alpha_posterior"]
     beta_posterior = outputs["beta_posterior"]
     alpha_prior = outputs["alpha_prior"]
     beta_prior = outputs["beta_prior"]
     return kl(Beta(alpha_posterior, beta_posterior),
               Beta(alpha_prior, beta_prior)).sum()
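
PyTorch registers a closed-form KL for pairs of Beta distributions, so the global term in Example #22 is a single `kl` call followed by a sum. A standalone toy sketch (hypothetical concentration values, not the model's actual posteriors):

import torch
from torch.distributions import Beta
from torch.distributions import kl_divergence as kl

# Toy per-parameter posterior and prior concentrations (shape: 3)
alpha_post = torch.tensor([2.0, 5.0, 1.5])
beta_post = torch.tensor([3.0, 1.0, 4.0])
alpha_prior = torch.ones(3)   # Beta(1, 1) is the uniform prior
beta_prior = torch.ones(3)

# Elementwise KL(Beta_post || Beta_prior), summed into one global scalar
kl_global = kl(Beta(alpha_post, beta_post), Beta(alpha_prior, beta_prior)).sum()
print(float(kl_global))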
Example #23
    def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None):
        is_labelled = False if y is None else True

        x_ = torch.log(1 + x)
        qz1_m, qz1_v, z1 = self.z_encoder(x_)
        ql_m, ql_v, library = self.l_encoder(x_)

        # Enumerate choices of label
        ys, z1s = (broadcast_labels(y, z1, n_broadcast=self.n_labels))
        qz2_m, qz2_v, z2 = self.encoder_z2_z1(z1s, ys)
        pz1_m, pz1_v = self.decoder_z1_z2(z2, ys)
        px_scale, px_r, px_rate, px_dropout = self.decoder(
            self.dispersion, z1, library, batch_index)

        reconst_loss = self._reconstruction_loss(x, px_rate, px_r, px_dropout,
                                                 batch_index, y)

        # KL Divergence
        mean = torch.zeros_like(qz2_m)
        scale = torch.ones_like(qz2_v)

        kl_divergence_z2 = kl(Normal(qz2_m, torch.sqrt(qz2_v)),
                              Normal(mean, scale)).sum(dim=1)
        loss_z1_unweight = -Normal(pz1_m,
                                   torch.sqrt(pz1_v)).log_prob(z1s).sum(dim=-1)
        loss_z1_weight = Normal(qz1_m,
                                torch.sqrt(qz1_v)).log_prob(z1).sum(dim=-1)
        kl_divergence_l = kl(Normal(ql_m, torch.sqrt(ql_v)),
                             Normal(local_l_mean,
                                    torch.sqrt(local_l_var))).sum(dim=1)

        if is_labelled:
            return reconst_loss + loss_z1_weight + loss_z1_unweight, kl_divergence_z2 + kl_divergence_l

        probs = self.classifier(z1)
        reconst_loss += (loss_z1_weight + (
            (loss_z1_unweight).view(self.n_labels, -1).t() * probs).sum(dim=1))

        kl_divergence = (kl_divergence_z2.view(self.n_labels, -1).t() *
                         probs).sum(dim=1)
        kl_divergence += kl(
            Categorical(probs=probs),
            Categorical(probs=self.y_prior.repeat(probs.size(0), 1)))
        kl_divergence += kl_divergence_l

        return reconst_loss, kl_divergence
Example #24
 def kl(self):
     w_mus = [weight_mu.view([-1]) for weight_mu in self.weight_mus]
     b_mus = [bias_mu.view([-1]) for bias_mu in self.bias_mus]
     mus = torch.cat(w_mus+b_mus)
     w_logsigs = [weight_logsig.view([-1]) for weight_logsig in self.weight_logsigs]
     b_logsigs = [bias_logsigs.view([-1]) for bias_logsigs in self.bias_logsigs]
     sigs = torch.cat(w_logsigs+b_logsigs).exp()
     q = Normal(mus, sigs)
     N = Normal(torch.zeros(len(mus), device=mus.device), torch.ones(len(mus), device=mus.device))
     return kl(q, N)
Example #25
    def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None):
        r""" Returns the reconstruction loss and the Kullback divergences

        :param x: tensor of values with shape (batch_size, n_input)
        :param local_l_mean: tensor of means of the prior distribution of latent variable l
         with shape (batch_size, 1)
        :param local_l_var: tensor of variances of the prior distribution of latent variable l
         with shape (batch_size, 1)
        :param batch_index: array that indicates which batch the cells belong to with shape ``batch_size``
        :param y: tensor of cell-types labels with shape (batch_size, n_labels)
        :return: the reconstruction loss and the Kullback divergences
        :rtype: 3-tuple of :py:class:`torch.FloatTensor`
        """
        # Parameters for z latent distribution
        outputs = self.inference(x, batch_index, None)
        qz_m = outputs["qz_m"]
        qz_v = outputs["qz_v"]
        ql_m = outputs["ql_m"]
        ql_v = outputs["ql_v"]
        px_rate = outputs["px_rate"]
        px_r = outputs["px_r"]
        px_dropout = outputs["px_dropout"]

        # KL Divergence
        mean = torch.zeros_like(qz_m)
        scale = torch.ones_like(qz_v)

        kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)),
                             Normal(mean, scale)).sum(dim=1)
        kl_divergence_l = kl(
            Normal(ql_m, torch.sqrt(ql_v)),
            Normal(local_l_mean, torch.sqrt(local_l_var)),
        ).sum(dim=1)
        kl_divergence = kl_divergence_z

        reconst_loss = self.get_reconstruction_loss(x, px_rate, px_r,
                                                    px_dropout)

        if self.reconstruction_loss == "mse" or self.reconstruction_loss == "nb":
            kl_divergence_l = 1.0
        print("reconst_loss=%f, kl_divergence=%f" %
              (torch.mean(reconst_loss), torch.mean(kl_divergence)))
        return reconst_loss + kl_divergence_l, kl_divergence, 0.0
Example #26
    def forward(self, X, local_l_mean=None, local_l_var=None):

        result = self.inference(X)

        latent_z_mu = result["latent_z_mu"]
        latent_z_logvar = result["latent_z_logvar"]
        latent_z = result["latent_z"]

        latent_l_mu = result["latent_l_mu"]
        latent_l_logvar = result["latent_l_logvar"]

        imputation = result["imputation"]
        disperation = result["disperation"]
        dropoutrate = result["dropoutrate"]

        # KL Divergence for library factor
        if local_l_mean is not None:
            kl_divergence_l = kl(Normal(latent_l_mu, latent_l_logvar),
                                 Normal(local_l_mean,
                                        torch.sqrt(local_l_var))).sum(dim=1)
        else:
            kl_divergence_l = torch.tensor(0.0)

        # KL Divergence for latent code
        if self.penality == "GMM":
            gamma, mu_c, var_c, pi = self.get_gamma(
                latent_z)  #, self.n_centroids, c_params)
            kl_divergence_z = GMM_loss(gamma, (mu_c, var_c, pi),
                                       (latent_z_mu, latent_z_logvar))

        else:
            mean = torch.zeros_like(latent_z_mu)
            scale = torch.ones_like(latent_z_logvar)
            kl_divergence_z = kl(Normal(latent_z_mu, latent_z_logvar),
                                 Normal(mean, scale)).sum(dim=1)

        reconst_loss = self.get_reconstruction_loss(X, imputation, disperation,
                                                    dropoutrate)

        return reconst_loss, kl_divergence_l, kl_divergence_z
Example #27
    def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None):
        r""" Returns the reconstruction loss and the Kullback divergences

        :param x: tensor of values with shape (batch_size, n_input)
        :param local_l_mean: tensor of means of the prior distribution of latent variable l
         with shape (batch_size, 1)
        :param local_l_var: tensor of variances of the prior distribution of latent variable l
         with shape (batch_size, 1)
        :param batch_index: array that indicates which batch the cells belong to with shape ``batch_size``
        :param y: tensor of cell-types labels with shape (batch_size, n_labels)
        :return: the reconstruction loss and the Kullback divergences
        :rtype: 2-tuple of :py:class:`torch.FloatTensor`
        """
        # Parameters for z latent distribution
        outputs = self.inference(x, batch_index, y)
        qz_m = outputs['qz_m']
        qz_v = outputs['qz_v']
        ql_m = outputs['ql_m']
        ql_v = outputs['ql_v']
        px_rate = outputs['px_rate']
        px_r = outputs['px_r']
        px_dropout = outputs['px_dropout']

        # KL Divergence
        mean, scale = self.get_prior_params(device=qz_m.device)

        kl_divergence_z = kl(self.z_encoder.distrib(qz_m, qz_v),
                             self.z_encoder.distrib(mean, scale))
        if len(kl_divergence_z.size()) == 2:
            kl_divergence_z = kl_divergence_z.sum(dim=1)
        kl_divergence_l = kl(Normal(ql_m, torch.sqrt(ql_v)), Normal(local_l_mean, torch.sqrt(local_l_var))).sum(dim=1)
        kl_divergence = kl_divergence_z

        reconst_loss = self.get_reconstruction_loss(x, px_rate, px_r, px_dropout)

        return reconst_loss + kl_divergence_l, kl_divergence
Example #28
    def forward(self, x, y=None):
        is_labelled = False if y is None else True

        qz1_m, qz1_v, z1 = self.z_encoder(x)

        # Enumerate choices of label
        ys, z1s = (
            broadcast_labels(
                y, z1, n_broadcast=self.n_labels
            )
        )
        qz2_m, qz2_v, z2 = self.encoder_z2_z1(z1s, ys)
        pz1_m, pz1_v = self.decoder_z1_z2(z2, ys)
        qx_m, qx_v = self.decoder(z1)

        reconst_loss = self._reconstruction_loss(x, qx_m, qx_v)

        # KL Divergence
        mean = torch.zeros_like(qz2_m)
        scale = torch.ones_like(qz2_v)

        kl_divergence_z2 = kl(Normal(qz2_m, torch.sqrt(qz2_v)), Normal(mean, scale)).sum(dim=1)
        loss_z1_unweight = - Normal(pz1_m, torch.sqrt(pz1_v)).log_prob(z1s).sum(dim=-1)
        loss_z1_weight = Normal(qz1_m, torch.sqrt(qz1_v)).log_prob(z1).sum(dim=-1)

        if is_labelled:
            return reconst_loss + loss_z1_weight + loss_z1_unweight, kl_divergence_z2

        probs = self.classifier(z1)
        reconst_loss += (loss_z1_weight + ((loss_z1_unweight).view(self.n_labels, -1).t() * probs).sum(dim=1))

        kl_divergence = (kl_divergence_z2.view(self.n_labels, -1).t() * probs).sum(dim=1)
        kl_divergence += kl(Categorical(probs=probs),
                            Categorical(probs=self.y_prior.repeat(probs.size(0), 1)))

        return reconst_loss, kl_divergence
Example #29
    def forward(self, x):
        r""" Returns the reconstruction loss

        :param x: tensor of values with shape (batch_size, n_input)

        :return: the reconstruction loss, the latent term (the KL divergence, or the
         variational log-density of z under message passing), and the message-passing
         likelihood (None when message passing is disabled)
        :rtype: 3-tuple
        """
        # Parameters for z latent distribution
        outputs = self.inference(x)
        qz_m = outputs["qz_m"]
        qz_v = outputs["qz_v"]
        px_rate = outputs["px_rate"]
        px_r = outputs["px_r"]
        z = outputs["z"]
        library = outputs["library"]

        self.encoder_variance.append(
            np.linalg.norm(qz_v.detach().cpu().numpy(), axis=1))

        if self.use_MP:
            # Message passing likelihood
            self.initialize_visit()
            self.initialize_messages(z, self.barcodes, self.n_latent)
            self.perform_message_passing((self.tree & self.root), z.shape[1],
                                         False)
            mp_lik = self.aggregate_messages_into_leaves_likelihood(
                z.shape[1], add_prior=True)
            # Gaussian variational likelihood
            qz = Normal(qz_m, torch.sqrt(qz_v)).log_prob(z).sum(dim=-1)
        else:
            mp_lik = None
            # scVI Kl Divergence
            mean = torch.zeros_like(qz_m)
            scale = torch.ones_like(qz_v)
            qz = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean,
                                                           scale)).sum(dim=1)

        # Reconstruction Loss
        if self.reconstruction_loss == "nb":
            reconst_loss = (-NegativeBinomial(
                mu=px_rate, theta=px_r).log_prob(x).sum(dim=-1))
        elif self.reconstruction_loss == "poisson":
            reconst_loss = -Poisson(px_rate).log_prob(x).sum(dim=-1)

        return reconst_loss, qz, mp_lik
Example #30
    def loss(
        self,
        tensors,
        inference_outputs,
        generative_outputs,
    ):

        x = tensors[REGISTRY_KEYS.X_KEY]
        qz_m = inference_outputs["qz_m"]
        qz_v = inference_outputs["qz_v"]
        p = generative_outputs["px"]

        kld = kl(
            Normal(qz_m, torch.sqrt(qz_v)),
            Normal(0, 1),
        ).sum(dim=1)
        rl = self.get_reconstruction_loss(p, x)
        loss = (0.5 * rl + 0.5 * (kld * self.kl_weight)).mean()
        return LossRecorder(loss, rl, kld)