def generate_joint(self,
                       x,
                       local_l_mean,
                       local_l_var,
                       batch_index,
                       y=None,
                       zero_inflated=True):
        """
        :param x: used only for shape match
        """
        n_batches, _ = x.shape
        device = "cuda" if torch.cuda.is_available() else "cpu"
        z_mean = torch.zeros(n_batches, self.n_latent, device=device)
        z_std = torch.ones(n_batches, self.n_latent, device=device)  # N(0, I) prior
        z_prior_dist = Normal(z_mean, z_std)
        z_sim = z_prior_dist.sample()

        l_prior_dist = Normal(local_l_mean, torch.sqrt(local_l_var))
        l_sim = l_prior_dist.sample()

        # Decoder pass
        px_scale, px_r, px_rate, px_dropout = self.decoder(
            self.dispersion, z_sim, l_sim, batch_index, y)

        if self.dispersion == "gene-label":
            px_r = F.linear(
                one_hot(y, self.n_labels),
                self.px_r)  # px_r is stored transposed; last dimension is n_genes
        elif self.dispersion == "gene-batch":
            px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
        elif self.dispersion == "gene":
            px_r = self.px_r
        px_r = torch.exp(px_r)

        # Data generation
        # NB as a Gamma-Poisson mixture: success probability p = mu / (mu + r)
        p = px_rate / (px_rate + px_r)
        r = px_r
        # Important remark: Gamma is parametrized by the rate = 1/scale!
        l_train = Gamma(concentration=r, rate=(1 - p) / p).sample()

        # Clamp: distribution objects can misbehave numerically when their
        # parameters are too large
        l_train = torch.clamp(l_train, max=1e8)
        gene_expressions = Poisson(
            l_train).sample()  # shape: (n_samples, n_cells_batch, n_genes)
        if zero_inflated:
            p_zero = torch.sigmoid(px_dropout)  # per-entry dropout probability
            random_prob = torch.rand_like(p_zero)
            gene_expressions[random_prob <= p_zero] = 0

        return gene_expressions, z_sim, l_sim
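
The draw above uses the standard Gamma-Poisson construction of the negative binomial, followed by Bernoulli masking for zero inflation. A minimal standalone sketch with toy values (mu, r, and the dropout logit are made up here) that checks the resulting moments:

import torch
from torch.distributions import Gamma, Poisson

# With p = mu / (mu + r), the Gamma rate (1 - p) / p equals r / mu, so
# lambda ~ Gamma(r, r / mu) and x ~ Poisson(lambda) gives x ~ NB(mean=mu, inv. disp.=r)
mu, r = torch.tensor(5.0), torch.tensor(2.0)
p = mu / (mu + r)
lam = Gamma(concentration=r, rate=(1 - p) / p).sample((100000,))
counts = Poisson(lam).sample()
print(counts.mean())  # ~= mu = 5
print(counts.var())   # ~= mu + mu ** 2 / r = 17.5

# Zero inflation: mask each entry with probability sigmoid(dropout logit)
p_zero = torch.sigmoid(torch.tensor(-1.0))
counts[torch.rand_like(counts) <= p_zero] = 0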
Example #2
    def forward(self, x: torch.Tensor, *cat_list: int):
        r"""Forward computation on ``x``.

        :param x: tensor of values with shape ``(n_in,)``
        :param cat_list: list of category membership(s) for this sample
        :return: tensor of shape ``(n_out,)``
        :rtype: :py:class:`torch.Tensor`
        """
        one_hot_cat_list = []  # for generality; some entries may remain unused
        assert len(self.n_cat_list) <= len(cat_list), (
            "number of categorical arguments provided doesn't match init. params.")
        for n_cat, cat in zip(self.n_cat_list, cat_list):
            assert not (n_cat and cat is None), (
                "cat not provided while n_cat != 0 in init. params.")
            if n_cat > 1:  # n_cat = 1 will be ignored - no additional information
                if cat.size(1) != n_cat:
                    one_hot_cat = one_hot(cat, n_cat)
                else:
                    one_hot_cat = cat  # cat has already been one_hot encoded
                one_hot_cat_list += [one_hot_cat]
        for layers in self.fc_layers:
            for layer in layers:
                if layer is not None:
                    if isinstance(layer, nn.BatchNorm1d):
                        if x.dim() == 3:
                            x = torch.cat([(layer(slice_x)).unsqueeze(0)
                                           for slice_x in x],
                                          dim=0)
                            # shape: (n_post_samples, n_batch, n_features)
                        else:
                            x = layer(x)
                    else:
                        if isinstance(layer, nn.Linear):
                            if x.dim() == 3:
                                # broadcast 2-d covariates across the sample
                                # dimension; already-expanded ones pass through
                                one_hot_cat_list = [
                                    o.unsqueeze(0).expand(
                                        (x.size(0), o.size(0), o.size(1)))
                                    if o.dim() == 2 else o
                                    for o in one_hot_cat_list
                                ]
                            x = torch.cat((x, *one_hot_cat_list), dim=-1)
                        x = layer(x)
        return x
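
This forward pass injects categorical covariates by concatenating one-hot encodings to the features ahead of each linear layer. A self-contained sketch of the pattern, using torch.nn.functional.one_hot in place of the module's own one_hot helper (layer sizes here are arbitrary):

import torch
import torch.nn as nn
import torch.nn.functional as F

n_in, n_out, n_cat = 10, 4, 3
layer = nn.Linear(n_in + n_cat, n_out)            # widened to absorb the one-hot
x = torch.randn(8, n_in)                          # (n_batch, n_in)
cat = torch.randint(n_cat, (8,))                  # integer category per sample
one_hot_cat = F.one_hot(cat, n_cat).float()       # (n_batch, n_cat)
out = layer(torch.cat((x, one_hot_cat), dim=-1))  # (n_batch, n_out)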
    def log_px_z(self, tensors, z):
        """Compute log p(x|z).

        Only works in the specific case where the library is observed and
        there are no batch indices.
        """
        (x, _, _, batch_index, _) = tensors
        library = x.sum(1, keepdim=True)

        px_scale, px_r, px_rate, px_dropout = self.decoder(
            self.dispersion, z, library, batch_index)
        if self.dispersion == "gene-label":
            raise ValueError
        elif self.dispersion == "gene-batch":
            px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
        elif self.dispersion == "gene":
            px_r = self.px_r
        px_r = torch.exp(px_r)
        # log p(x|z) is the negative reconstruction loss
        res = -self._reconstruction_loss(x, px_rate, px_r, px_dropout)
        return res
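
When there is no zero inflation and dispersion is "gene", the term negated above is a negative-binomial log-likelihood. _reconstruction_loss itself is not shown in this excerpt, so its exact parametrization is an assumption; a sketch with torch's built-in distribution and toy tensors:

import torch
from torch.distributions import NegativeBinomial

px_rate = torch.tensor([[5.0, 2.0]])  # per-gene NB mean (mu)
px_r = torch.tensor([2.0, 1.0])       # per-gene inverse dispersion (r)
x = torch.tensor([[3.0, 0.0]])        # observed counts
# torch parametrizes NB by logits = log(p / (1 - p)); since mean = r * p / (1 - p),
# this reduces to log(mu) - log(r)
logits = px_rate.log() - px_r.log()
log_px_z = NegativeBinomial(total_count=px_r, logits=logits).log_prob(x).sum(-1)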
    def inference(
        self,
        x,
        batch_index=None,
        y=None,
        n_samples=1,
        reparam=True,
        observed_library=None,
        encoder_key: str = "default",
        counts: torch.Tensor = None,
        z_encoder=None,
    ):
        # Use the externally provided evaluation encoder when one is given
        z_enc_of_use = self.z_encoder if z_encoder is None else z_encoder
        x_ = x
        if self.log_variational:
            x_ = torch.log1p(x_)  # numerically stable log(1 + x)

        # Library sampling
        library_post = self.l_encoder(x_, n_samples=n_samples, reparam=reparam)
        library_variables = dict(
            ql_m=library_post["q_m"],
            ql_v=library_post["q_v"],
            library=library_post["latent"],
        )

        if observed_library is None:
            library = library_variables["library"]
        else:
            library = observed_library

        # Z sampling
        if encoder_key != "defensive":
            z_post = z_enc_of_use[encoder_key](x_,
                                               y,
                                               n_samples=n_samples,
                                               reparam=reparam)
        else:
            z_post = self.z_defensive_sampling(x_,
                                               counts=counts,
                                               z_encoder=z_encoder)

        z_variables = dict(
            qz_m=z_post["q_m"],
            qz_v=z_post["q_v"],
            z=z_post["latent"],
            log_qz_x=z_post["posterior_density"],
        )
        self.debug_ranges.append(
            dict(z=(z_post["latent"].min().item(),
                    z_post["latent"].max().item())))

        # Decoder pass
        px_scale, px_r, px_rate, px_dropout = self.decoder(
            self.dispersion, z_post["latent"], library, batch_index, y)

        if self.dispersion == "gene-label":
            px_r = F.linear(
                one_hot(y, self.n_labels),
                self.px_r)  # px_r is stored transposed; last dimension is n_genes
        elif self.dispersion == "gene-batch":
            px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
        elif self.dispersion == "gene":
            px_r = self.px_r
        px_r = torch.exp(px_r)
        decoder_variables = dict(px_scale=px_scale,
                                 px_r=px_r,
                                 px_rate=px_rate,
                                 px_dropout=px_dropout)

        return {**decoder_variables, **library_variables, **z_variables}
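
The encoders here are assumed to return dicts keyed q_m, q_v, latent, and posterior_density, with n_samples and reparam controlling the draw. A minimal sketch of those mechanics for a diagonal-Gaussian posterior:

import torch
from torch.distributions import Normal

q_m = torch.zeros(8, 10, requires_grad=True)  # (n_batch, n_latent)
q_v = torch.ones(8, 10)
post = Normal(q_m, q_v.sqrt())
z = post.rsample((5,))                        # reparam=True: gradients reach q_m
# z = post.sample((5,))                       # reparam=False: detached draws
log_qz_x = post.log_prob(z).sum(-1)           # the "posterior_density" entry
# z: (n_samples, n_batch, n_latent); log_qz_x: (n_samples, n_batch)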
Example #5
File: vae.py  Project: Juan-JV/sbVAE
    def inference(
        self,
        x,
        batch_index=None,
        y=None,
        n_samples=1,
        reparam=True,
        observed_library=None,
        encoder_key: str = "default",
        counts: torch.Tensor = None,
    ):
        x_ = x
        if self.log_variational:
            x_ = torch.log1p(x_)  # numerically stable log(1 + x)

        # Library sampling
        library_post = self.l_encoder(x_, n_samples=n_samples, reparam=reparam)
        library_variables = dict(
            ql_m=library_post["q_m"],
            ql_v=library_post["q_v"],
            library=library_post["latent"],
        )

        if observed_library is None:
            library = library_variables["library"]
        else:
            library = observed_library

        # Z sampling
        if encoder_key != "defensive":
            z_post = self.z_encoder[encoder_key](x_,
                                                 y,
                                                 n_samples=n_samples,
                                                 reparam=reparam)
        else:
            z_post = self.z_defensive_sampling(x_, counts=counts)

        if self.do_iaf or encoder_key == "defensive":
            # IAF does not parametrize the means/covariances of the variational posterior
            z_variables = dict(
                qz_m=None,
                qz_v=None,
                z=z_post["latent"],
                log_qz_x=z_post["posterior_density"],
            )
        else:
            z_variables = dict(
                qz_m=z_post["q_m"],
                qz_v=z_post["q_v"],
                z=z_post["latent"],
                log_qz_x=None,
            )

        # Decoder pass
        px_scale, px_r, px_rate, px_dropout = self.decoder(
            self.dispersion, z_post["latent"], library, batch_index, y)

        if self.dispersion == "gene-label":
            px_r = F.linear(
                one_hot(y, self.n_labels),
                self.px_r)  # px_r is stored transposed; last dimension is n_genes
        elif self.dispersion == "gene-batch":
            px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
        elif self.dispersion == "gene":
            px_r = self.px_r
        px_r = torch.exp(px_r)
        decoder_variables = dict(px_scale=px_scale,
                                 px_r=px_r,
                                 px_rate=px_rate,
                                 px_dropout=px_dropout)

        return {**decoder_variables, **library_variables, **z_variables}
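
In the do_iaf branch above, qz_m and qz_v are None because a flow posterior has no closed-form moments; only log q(z|x) survives, via the change-of-variables formula. A toy illustration with one invertible affine map standing in for the IAF stack:

import torch
from torch.distributions import Normal

base = Normal(torch.zeros(10), torch.ones(10))
eps = base.rsample((8,))            # (n_batch, n_latent) base draws
log_scale = torch.randn(10) * 0.1   # toy affine "flow" parameters
shift = torch.randn(10)
z = torch.exp(log_scale) * eps + shift
# log q(z|x) = log p(eps) - log|det Jacobian| = log p(eps) - sum(log_scale)
log_qz_x = base.log_prob(eps).sum(-1) - log_scale.sum()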