Exemplo n.º 1
0
    def get_reconstruction_loss(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        px_: Dict[str, torch.Tensor],
        py_: Dict[str, torch.Tensor],
        pro_batch_mask_minibatch: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the gene and protein reconstruction losses.

        Parameters
        ----------
        x
            Observed gene data, scored against the parameters in ``px_``.
        y
            Observed protein data, scored against the parameters in ``py_``.
        px_
            Gene likelihood parameters. Reads ``"rate"`` and ``"r"``, plus
            ``"dropout"`` when ``self.reconstruction_loss_gene == "zinb"``.
        py_
            Protein mixture likelihood parameters. Reads ``"rate_back"``,
            ``"rate_fore"``, ``"r"`` and ``"mixing"``.
        pro_batch_mask_minibatch
            Optional mask (converted with ``.bool()``) that must be
            broadcastable to the element-wise protein loss. Elements where the
            mask is False contribute 0 to the protein loss. ``None`` keeps
            every element.

        Returns
        -------
        ``(reconst_loss_gene, reconst_loss_protein)``: negative
        log-likelihoods, each summed over the last dimension.
        """
        # Gene reconstruction loss: ZINB when configured, otherwise NB.
        if self.reconstruction_loss_gene == "zinb":
            reconst_loss_gene = -log_zinb_positive(
                x, px_["rate"], px_["r"], px_["dropout"]
            ).sum(dim=-1)
        else:
            reconst_loss_gene = -log_nb_positive(x, px_["rate"], px_["r"]).sum(dim=-1)

        # Protein loss is kept element-wise until masking has been applied.
        reconst_loss_protein_full = -log_mixture_nb(
            y, py_["rate_back"], py_["rate_fore"], py_["r"], None, py_["mixing"]
        )
        if pro_batch_mask_minibatch is not None:
            # Zero excluded elements at their own positions. Do NOT use
            # ``masked_scatter_`` here: it fills the True positions with source
            # values taken sequentially from the start of the source. Any False
            # entry in the mask would therefore shift the surviving losses onto
            # the wrong elements.
            reconst_loss_protein_full = reconst_loss_protein_full.masked_fill(
                ~pro_batch_mask_minibatch.bool(), 0.0
            )
        reconst_loss_protein = reconst_loss_protein_full.sum(dim=-1)

        return reconst_loss_gene, reconst_loss_protein
Exemplo n.º 2
0
    def get_reconstruction_loss(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        px_: Dict[str, torch.Tensor],
        py_: Dict[str, torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the gene and protein reconstruction losses.

        The gene term uses ``log_zinb_positive`` when
        ``self.reconstruction_loss_gene == "zinb"`` and ``log_nb_positive``
        otherwise. The protein term uses ``log_mixture_nb``. Both are negated
        and summed over the last dimension.

        Parameters
        ----------
        x
            Observed gene data, scored against ``px_``.
        y
            Observed protein data, scored against ``py_``.
        px_
            Gene parameters. Reads ``"rate"`` and ``"r"``, plus ``"dropout"``
            in the ZINB case.
        py_
            Protein parameters. Reads ``"rate_back"``, ``"rate_fore"``,
            ``"r"`` and ``"mixing"``.

        Returns
        -------
        ``(gene_loss, protein_loss)`` as negative log-likelihood sums.
        """
        zero_inflated = self.reconstruction_loss_gene == "zinb"
        if zero_inflated:
            gene_log_lik = log_zinb_positive(
                x, px_["rate"], px_["r"], px_["dropout"]
            )
        else:
            gene_log_lik = log_nb_positive(x, px_["rate"], px_["r"])
        gene_loss = -gene_log_lik.sum(dim=-1)

        protein_log_lik = log_mixture_nb(
            y,
            py_["rate_back"],
            py_["rate_fore"],
            py_["r"],
            None,  # NOTE(review): fifth positional argument is passed as None; confirm its meaning in log_mixture_nb
            py_["mixing"],
        )
        protein_loss = -protein_log_lik.sum(dim=-1)

        return gene_loss, protein_loss