Example #1
    def inference(self, x, batch_index=None, y=None, n_samples=1):

        x_ = x
        # For the NB likelihood, the library size is computed directly from
        # the observed counts rather than inferred by the encoder.
        if self.reconstruction_loss == "nb":
            library_nb = x_.sum(dim=-1).reshape(-1, 1)
            if self.log_variational:
                library_nb = torch.log(library_nb)
        if self.log_variational:
            x_ = torch.log(1 + x_)

        # Sampling
        qz_m, qz_v, z = self.z_encoder(x_, y)
        ql_m, ql_v, library = self.l_encoder(x_)
        if self.reconstruction_loss == "nb":
            library = library_nb

        # Broadcast the posterior parameters and draw all Monte Carlo samples
        # in one call; the samples land on a new leading dimension.
        if n_samples > 1:
            qz_m = qz_m.unsqueeze(0).expand(
                (n_samples, qz_m.size(0), qz_m.size(1)))
            qz_v = qz_v.unsqueeze(0).expand(
                (n_samples, qz_v.size(0), qz_v.size(1)))
            z = Normal(qz_m, qz_v.sqrt()).sample()
            ql_m = ql_m.unsqueeze(0).expand(
                (n_samples, ql_m.size(0), ql_m.size(1)))
            ql_v = ql_v.unsqueeze(0).expand(
                (n_samples, ql_v.size(0), ql_v.size(1)))
            library = Normal(ql_m, ql_v.sqrt()).sample()

        px_scale, px_r, px_rate, px_dropout = self.decoder(
            self.dispersion, z, library, batch_index, y)
        if self.dispersion == "gene-label":
            px_r = F.linear(
                one_hot(y, self.n_labels),
                self.px_r)  # px_r gets transposed - last dimension is nb genes
        elif self.dispersion == "gene-batch":
            px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
        elif self.dispersion == "gene":
            px_r = self.px_r
        px_r = torch.exp(px_r)

        return dict(
            px_scale=px_scale,
            px_r=px_r,
            px_rate=px_rate,
            px_dropout=px_dropout,
            qz_m=qz_m,
            qz_v=qz_v,
            z=z,
            ql_m=ql_m,
            ql_v=ql_v,
            library=library,
        )
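
A self-contained sketch of the broadcast-and-sample pattern used above, with purely illustrative shapes:

import torch
from torch.distributions import Normal

# Illustrative shapes: a batch of 4 cells with a 2-dimensional latent space.
qz_m = torch.zeros(4, 2)  # posterior means
qz_v = torch.ones(4, 2)   # posterior variances

n_samples = 3
# Expand parameters to (n_samples, batch, latent), then sample once;
# each Monte Carlo sample occupies one slice of the leading dimension.
qz_m = qz_m.unsqueeze(0).expand(n_samples, qz_m.size(0), qz_m.size(1))
qz_v = qz_v.unsqueeze(0).expand(n_samples, qz_v.size(0), qz_v.size(1))
z = Normal(qz_m, qz_v.sqrt()).sample()
print(z.shape)  # torch.Size([3, 4, 2])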
Example #2
    def forward(self, dispersion: str, z: torch.Tensor, library: torch.Tensor,
                *cat_list: int):
        # The decoder returns values for the parameters of the ZINB distribution
        p1_ = self.factor_regressor(z)
        if self.n_batches > 1:
            # Drop the last one-hot column: a full one-hot encoding would be
            # linearly redundant with the regressor's bias term.
            one_hot_cat = one_hot(cat_list[0], self.n_batches)[:, :-1]
            p2_ = self.batch_regressor(one_hot_cat)
            raw_px_scale = p1_ + p2_
        else:
            raw_px_scale = p1_

        px_scale = torch.softmax(raw_px_scale, dim=-1)
        px_dropout = self.px_dropout_decoder(z)
        px_rate = torch.exp(library) * px_scale
        px_r = None

        return px_scale, px_r, px_rate, px_dropout
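
A minimal sketch of why px_rate = torch.exp(library) * px_scale yields expected counts: each softmax row sums to one, so the (exponentiated) log-library size sets the per-cell total. All values below are hypothetical:

import torch

# Hypothetical decoder output for 4 cells over 5 genes.
raw_px_scale = torch.randn(4, 5)
library = torch.log(torch.tensor([[1000.0], [2000.0], [1500.0], [500.0]]))

# softmax turns each row into a normalized expression profile (sums to 1),
# so scaling by exp(library) distributes the cell's total counts over genes.
px_scale = torch.softmax(raw_px_scale, dim=-1)
px_rate = torch.exp(library) * px_scale
print(px_rate.sum(dim=-1))  # ~tensor([1000., 2000., 1500., 500.])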
Example #3
    def forward(self, x: torch.Tensor, *cat_list: int, instance_id: int = 0):
        r"""Forward computation on ``x``.

        :param x: tensor of values with shape ``(n_in,)``
        :param cat_list: list of category membership(s) for this sample
        :param instance_id: use a specific conditional batch normalization instance
        :return: tensor of shape ``(n_out,)``
        :rtype: :py:class:`torch.Tensor`
        """
        one_hot_cat_list = []  # kept for generality; some entries may go unused
        assert len(self.n_cat_list) <= len(cat_list), (
            "number of categorical args provided doesn't match init. params.")
        for n_cat, cat in zip(self.n_cat_list, cat_list):
            assert not (n_cat and cat is None
                        ), "cat not provided while n_cat != 0 in init. params."
            if n_cat > 1:  # n_cat = 1 will be ignored - no additional information
                if cat.size(1) != n_cat:
                    one_hot_cat = one_hot(cat, n_cat)
                else:
                    one_hot_cat = cat  # cat has already been one_hot encoded
                one_hot_cat_list += [one_hot_cat]
        for layers in self.fc_layers:
            for layer in layers:
                if layer is not None:
                    if isinstance(layer, nn.BatchNorm1d):
                        if x.dim() == 3:
                            x = torch.cat([(layer(slice_x)).unsqueeze(0)
                                           for slice_x in x],
                                          dim=0)
                        else:
                            x = layer(x)
                    else:
                        if isinstance(layer, nn.Linear):
                            if x.dim() == 3:
                                one_hot_cat_list = [
                                    o.unsqueeze(0).expand(
                                        (x.size(0), o.size(0), o.size(1)))
                                    for o in one_hot_cat_list
                                ]
                            x = torch.cat((x, *one_hot_cat_list), dim=-1)
                        x = layer(x)
        return x
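
The repo's own one_hot helper isn't shown in these excerpts; a small sketch of the covariate-injection pattern using torch.nn.functional.one_hot as a stand-in (shapes are illustrative):

import torch
import torch.nn.functional as F

# Hypothetical batch of 3 samples, 4 input features, 2 categories.
x = torch.randn(3, 4)
cat = torch.tensor([[0], [1], [0]])
n_cat = 2

# One-hot encode the category and concatenate it to the features before
# the linear layer, so the layer is conditioned on the covariate.
one_hot_cat = F.one_hot(cat.squeeze(-1), n_cat).float()
x_cond = torch.cat((x, one_hot_cat), dim=-1)
print(x_cond.shape)  # torch.Size([3, 6])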
Example #4
    def inference(self,
                  x,
                  batch_index=None,
                  y=None,
                  local_l_mean=None,
                  local_l_var=None,
                  update=False,
                  n_samples=1):
        x_ = x
        if len(x_) != 2:
            raise ValueError(
                "Input training data should be 2 data types (RNA and ATAC), "
                "but input was only {}.".format(len(x_)))
        x_rna = x_[0]
        x_atac = x_[1]
        if self.log_variational:
            x_rna = torch.log(1 + x_rna)
            # x_atac = torch.log(1 + x_atac)

        # Sampling
        qz_rna_m, qz_rna_v, rna_z = self.RNA_encoder(x_rna, y)
        qz_atac_m, qz_atac_v, atac_z = self.ATAC_encoder(x_atac, y)
        qz_m, qz_v, z = self.RNA_ATAC_encoder([x_rna, x_atac], y)
        if self.isLibrary:
            ql_m, ql_v, l_z = self.l_encoder(x_rna, y)
        gamma, mu_c, var_c, pi = self.get_gamma(z, update)
        # For each cell, select the parameters of its most likely cluster
        # via advanced indexing (this replaces an explicit per-cell loop).
        index = torch.argmax(gamma, dim=1)
        index1 = list(range(len(index)))
        mu_c_max = mu_c[index1, :, index]
        var_c_max = var_c[index1, :, index]
        z_c_max = reparameterize_gaussian(mu_c_max, var_c_max)

        if self.isLibrary:
            libary_scale = l_z
        else:
            libary_scale = reparameterize_gaussian(local_l_mean, local_l_var)
        # decoder
        p_rna_scale, p_rna_r, p_rna_rate, p_rna_dropout, p_atac_scale, p_atac_r, p_atac_mean, p_atac_dropout \
            = self.RNA_ATAC_decoder(z, z_c_max, y, libary_scale=libary_scale, gamma=gamma)

        # Re-encode the reconstructions and recompute cluster responsibilities.
        rec_rna_mu, rec_rna_v, rec_rna_z = self.RNA_encoder(p_rna_rate, y)
        gamma_rna_rec, _, _, _ = self.get_gamma(rec_rna_z)
        rec_atac_mu, rec_atac_v, rec_atac_z = self.ATAC_encoder(p_atac_mean, y)
        gamma_atac_rec, _, _, _ = self.get_gamma(rec_atac_z)

        if self.dispersion == "gene-label":
            p_rna_r = F.linear(
                one_hot(y, self.n_labels),
                self.px_r)  # px_r gets transposed - last dimension is nb genes
            p_atac_r = F.linear(one_hot(y, self.n_labels), self.p_atac_r)
        elif self.dispersion == "gene-batch":
            p_rna_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
            p_atac_r = F.linear(one_hot(batch_index, self.n_batch),
                                self.p_atac_r)
        elif self.dispersion == "gene":
            p_rna_r = self.px_r
            p_atac_r = self.p_atac_r

        p_rna_r = torch.exp(p_rna_r)
        p_atac_r = torch.exp(p_atac_r)

        return dict(
            p_rna_scale=p_rna_scale,
            p_rna_r=p_rna_r,
            p_rna_rate=p_rna_rate,
            p_rna_dropout=p_rna_dropout,
            p_atac_scale=p_atac_scale,
            p_atac_r=p_atac_r,
            p_atac_mean=p_atac_mean,
            p_atac_dropout=p_atac_dropout,
            qz_rna_m=qz_rna_m,
            qz_rna_v=qz_rna_v,
            rna_z=rna_z,
            qz_atac_m=qz_atac_m,
            qz_atac_v=qz_atac_v,
            atac_z=atac_z,
            qz_m=qz_m,
            qz_v=qz_v,
            z=z,
            mu_c=mu_c,
            var_c=var_c,
            gamma=gamma,
            pi=pi,
            mu_c_max=mu_c_max,
            var_c_max=var_c_max,
            z_c_max=z_c_max,
            gamma_rna_rec=gamma_rna_rec,
            gamma_atac_rec=gamma_atac_rec,
            rec_atac_mu=rec_atac_mu,
            rec_atac_v=rec_atac_v,
            rec_rna_mu=rec_rna_mu,
            rec_rna_v=rec_rna_v,
        )
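
A small sketch of the advanced-indexing trick mu_c[index1, :, index] used in both multi-modal examples, with hypothetical shapes:

import torch

# Hypothetical mixture parameters: 5 cells, 3 latent dims, 4 clusters.
mu_c = torch.randn(5, 3, 4)
gamma = torch.softmax(torch.randn(5, 4), dim=-1)  # cluster responsibilities

# Pairing each row index with its argmax cluster index selects one
# (latent-dim,) column per cell, with no explicit loop.
index = torch.argmax(gamma, dim=1)  # shape (5,)
rows = torch.arange(len(index))
mu_c_max = mu_c[rows, :, index]     # shape (5, 3)
print(mu_c_max.shape)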
Example #5
    def inference(self, x, batch_index=None, y=None, local_l_mean=None, local_l_var=None, update=False, n_samples=1):
        x_ = x
        if len(x_) != 2:
            raise ValueError(
                "Input training data should be 2 data types (RNA and ATAC), "
                "but input was only {}.".format(len(x_)))
        x_rna = x_[0]
        x_atac = x_[1]
        libary_atac = torch.log(x_[1].sum(dim=-1)).reshape(-1, 1)
        libary_rna = torch.log(x_[0].sum(dim=-1)).reshape(-1, 1)
        if self.log_variational:
            x_rna = torch.log(1 + x_rna)
            x_atac = torch.log(1 + x_atac)

        # Sampling
        ql_m, ql_v, l_z = None, None, None  # returned below even when isLibrary is False
        if self.isLibrary:
            ql_m, ql_v, l_z = self.l_encoder(x_rna, batch_index)
        qz_rna_m, qz_rna_v, rna_z = self.RNA_encoder(x_rna, batch_index)
        qz_atac_m, qz_atac_v, atac_z = self.ATAC_encoder(x_atac, batch_index)
        qz_m, qz_v, z = self.RNA_ATAC_encoder([x_rna, x_atac], batch_index)

        # Fuse the per-modality posteriors into a joint one: means are
        # concatenated and mapped down; variances are fused in log space.
        qz_joint_mu = self.concatenter(torch.cat((qz_rna_m, qz_atac_m), 1))
        qz_joint_v = self.concatenter(torch.cat((torch.log(qz_rna_v), torch.log(qz_atac_v)), 1))
        qz_joint_v = torch.exp(qz_joint_v)
        qz_joint_z = Normal(qz_joint_mu, qz_joint_v.sqrt()).rsample()
        gamma_joint, _, _, _ = self.get_gamma(qz_joint_z)

        gamma, mu_c, var_c, pi = self.get_gamma(z, update)
        # For each cell, select the parameters of its most likely cluster.
        index = torch.argmax(gamma, dim=1)
        index1 = list(range(len(index)))
        mu_c_max = mu_c[index1, :, index]
        var_c_max = var_c[index1, :, index]
        z_c_max = reparameterize_gaussian(mu_c_max, var_c_max)

        if self.isLibrary:
            libary_scale = libary_rna
        else:
            libary_scale = reparameterize_gaussian(local_l_mean, local_l_var)
        # decoder
        p_rna_scale, p_rna_r, p_rna_rate, p_rna_dropout, p_atac_scale, p_atac_r, p_atac_mean, p_atac_dropout \
            = self.RNA_ATAC_decoder(z, z_c_max, batch_index, libary_scale=libary_scale, gamma=gamma, libary_atac=libary_atac)
        # classifier
        classifer_loss = 0
        if self.classifer_num > 0 and y is not None:
            classifer_pred = self.classifer(z)
            classifer_loss = -100 * (
                one_hot(y, self.classifer_num) * torch.log(classifer_pred + 1.0e-10)
            ).sum(dim=-1)

        if self.log_variational:
            p_rna_rate_norm = torch.log(1 + p_rna_rate)
            p_atac_mean_norm = torch.log(1 + p_atac_mean)
        else:
            p_rna_rate_norm = p_rna_rate
            p_atac_mean_norm = p_atac_mean
        # Re-encode the reconstructions and recompute cluster responsibilities.
        rec_rna_mu, rec_rna_v, rec_rna_z = self.RNA_encoder(p_rna_rate_norm, batch_index)
        gamma_rna_rec, _, _, _ = self.get_gamma(rec_rna_z)
        rec_atac_mu, rec_atac_v, rec_atac_z = self.ATAC_encoder(p_atac_mean_norm, batch_index)
        gamma_atac_rec, _, _, _ = self.get_gamma(rec_atac_z)
        rec_joint_mu = self.concatenter(torch.cat((rec_rna_mu, rec_atac_mu), 1))
        rec_joint_v = self.concatenter(torch.cat((torch.log(rec_rna_v), torch.log(rec_atac_v)), 1))
        rec_joint_v = torch.exp(rec_joint_v)
        rec_joint_z = Normal(rec_joint_mu, rec_joint_v.sqrt()).rsample()
        gamma_joint_rec, _, _, _ = self.get_gamma(rec_joint_z)

        if self.dispersion == "gene-label":
            p_rna_r = F.linear(
                one_hot(y, self.n_labels), self.px_r
            )  # px_r gets transposed - last dimension is nb genes
            p_atac_r = F.linear(
                one_hot(y, self.n_labels), self.p_atac_r
            )
        elif self.dispersion == "gene-batch":
            p_rna_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
            p_atac_r = F.linear(one_hot(batch_index, self.n_batch), self.p_atac_r)
        elif self.dispersion == "gene":
            p_rna_r = self.px_r
            p_atac_r = self.p_atac_r

        p_rna_r = torch.exp(p_rna_r)
        p_atac_r = torch.exp(p_atac_r)

        return dict(
            p_rna_scale=p_rna_scale,
            p_rna_r=p_rna_r,
            p_rna_rate=p_rna_rate,
            p_rna_dropout=p_rna_dropout,
            p_atac_scale=p_atac_scale,
            p_atac_r=p_atac_r,
            p_atac_mean=p_atac_mean,
            p_atac_dropout=p_atac_dropout,
            qz_rna_m=qz_rna_m,
            qz_rna_v=qz_rna_v,
            rna_z=rna_z,
            qz_atac_m=qz_atac_m,
            qz_atac_v=qz_atac_v,
            atac_z=atac_z,
            qz_m=qz_m,
            qz_v=qz_v,
            z=z,
            mu_c=mu_c,
            var_c=var_c,
            gamma=gamma,
            pi=pi,
            mu_c_max=mu_c_max,
            var_c_max=var_c_max,
            z_c_max=z_c_max,
            gamma_rna_rec=gamma_rna_rec,
            gamma_atac_rec=gamma_atac_rec,
            rec_atac_mu=rec_atac_mu,
            rec_atac_v=rec_atac_v,
            rec_rna_mu=rec_rna_mu,
            rec_rna_v=rec_rna_v,
            ql_m=ql_m,
            ql_v=ql_v,
            l_z=l_z,
            rec_joint_mu=rec_joint_mu,
            rec_joint_v=rec_joint_v,
            rec_joint_z=rec_joint_z,
            gamma_joint_rec=gamma_joint_rec,
            qz_joint_mu=qz_joint_mu,
            qz_joint_v=qz_joint_v,
            qz_joint_z=qz_joint_z,
            gamma_joint=gamma_joint,
            classifer_loss=classifer_loss,
        )
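
The helper reparameterize_gaussian is not defined in these excerpts; a plausible minimal definition consistent with its use here (mean and variance in, differentiable sample out) is:

import torch
from torch.distributions import Normal

def reparameterize_gaussian(mu, var):
    # rsample() keeps the draw differentiable w.r.t. mu and var
    # (mu + sqrt(var) * eps under the hood).
    return Normal(mu, var.sqrt()).rsample()

mu = torch.zeros(4, 2, requires_grad=True)
var = torch.ones(4, 2)
z = reparameterize_gaussian(mu, var)
z.sum().backward()
print(mu.grad.shape)  # gradients flow back through the sample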