Example #1
    def inference(self, x, batch_index=None, y=None, n_samples=1):
        x_ = x
        if self.log_variational:
            x_ = torch.log(1 + x_)

        # Sampling
        qz_m, qz_v, z = self.z_encoder(x_, y)
        px_r, px_rate = self.decoder(self.dispersion, z, batch_index, y)

        if self.dispersion == "gene-label":
            px_r = F.linear(
                one_hot(y, self.n_labels),
                self.px_r)  # px_r gets transposed - last dimension is nb genes
        elif self.dispersion == "gene-batch":
            px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
        elif self.dispersion == "gene":
            px_r = self.px_r
        px_rate = torch.exp(px_rate)
        px_r = torch.exp(px_r)

        return dict(
            px_r=px_r,
            px_rate=px_rate,
            qz_m=qz_m,
            qz_v=qz_v,
            z=z,
        )
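The `F.linear(one_hot(...), self.px_r)` pattern used throughout these snippets selects one column of the dispersion parameter matrix per cell. A minimal, self-contained sketch follows; the `one_hot` helper and the `(n_genes, n_labels)` shape of `px_r` are assumptions inferred from how the examples use them, not the library code itself.

import torch
import torch.nn.functional as F

def one_hot(index, n_cat):
    # assumed helper: (batch, 1) integer category ids -> (batch, n_cat) float one-hot
    onehot = torch.zeros(index.size(0), n_cat, device=index.device)
    onehot.scatter_(1, index.type(torch.long), 1)
    return onehot.type(torch.float32)

n_genes, n_labels = 4, 3
px_r = torch.randn(n_genes, n_labels)        # one dispersion column per label
y = torch.tensor([[0], [2]])                 # label index for each of two cells

# F.linear computes one_hot(y) @ px_r.T, hence the recurring
# "px_r gets transposed - last dimension is nb genes" comment
per_cell_r = F.linear(one_hot(y, n_labels), px_r)
print(per_cell_r.shape)                              # torch.Size([2, 4])
print(torch.allclose(per_cell_r[0], px_r[:, 0]))     # cell 0 got label 0's column: True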
Example #2
    def decode(
        self,
        z: torch.Tensor,
        mode: int,
        library: torch.Tensor,
        batch_index: Optional[torch.Tensor] = None,
        y: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        px_scale, px_r, px_rate, px_dropout = self.decoder(
            z, mode, library, self.dispersion, batch_index, y
        )

        if self.dispersion == "gene-label":
            px_r = F.linear(one_hot(y, self.n_labels), self.px_r)
        elif self.dispersion == "gene-batch":
            px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
        elif self.dispersion == "gene":
            px_r = self.px_r.view(1, self.px_r.size(0))
        px_r = torch.exp(px_r)

        px_scale = px_scale / torch.sum(
            px_scale[:, self.indices_mappings[mode]], dim=1
        ).view(-1, 1)
        px_rate = px_scale * torch.exp(library)

        return px_scale, px_r, px_rate, px_dropout
Example #3
    def _reconstruction_loss(self,
                             x,
                             px_rate,
                             px_r,
                             px_dropout,
                             batch_index,
                             y,
                             mode="scRNA",
                             weighting=1):
        if self.dispersion == "gene-label":
            px_r = F.linear(
                one_hot(y, self.n_labels),
                self.px_r)  # px_r gets transposed - last dimension is nb genes
        elif self.dispersion == "gene-batch":
            px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
        elif self.dispersion == "gene":
            px_r = self.px_r

        # Reconstruction Loss
        if mode == "scRNA":
            if self.reconstruction_loss == 'zinb':
                reconst_loss = -log_zinb_positive(x, px_rate, torch.exp(px_r),
                                                  px_dropout)
            elif self.reconstruction_loss == 'nb':
                reconst_loss = -log_nb_positive(x, px_rate, torch.exp(px_r))

        else:
            if self.reconstruction_loss_fish == 'poisson':
                reconst_loss = -torch.sum(Poisson(px_rate).log_prob(x), dim=1)
            elif self.reconstruction_loss_fish == 'gaussian':
                reconst_loss = -torch.sum(Normal(px_rate, 10).log_prob(x),
                                          dim=1)
        return reconst_loss
Example #4
    def reshape_bernoulli(
        self,
        bernoulli_params: torch.Tensor,
        batch_index: Optional[torch.Tensor] = None,
        y: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if self.zero_inflation == "gene-label":
            one_hot_label = one_hot(y, self.n_labels)
            # If we sampled several random Bernoulli parameters
            if len(bernoulli_params.shape) == 2:
                bernoulli_params = F.linear(one_hot_label, bernoulli_params)
            else:
                bernoulli_params_res = []
                for sample in range(bernoulli_params.shape[0]):
                    bernoulli_params_res.append(
                        F.linear(one_hot_label, bernoulli_params[sample])
                    )
                bernoulli_params = torch.stack(bernoulli_params_res)
        elif self.zero_inflation == "gene-batch":
            one_hot_batch = one_hot(batch_index, self.n_batch)
            if len(bernoulli_params.shape) == 2:
                bernoulli_params = F.linear(one_hot_batch, bernoulli_params)
            # If we sampled several random Bernoulli parameters
            else:
                bernoulli_params_res = []
                for sample in range(bernoulli_params.shape[0]):
                    bernoulli_params_res.append(
                        F.linear(one_hot_batch, bernoulli_params[sample])
                    )
                bernoulli_params = torch.stack(bernoulli_params_res)

        return bernoulli_params
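The per-sample loop above exists because `F.linear` expects a 2-D weight. Assuming the sampled parameters have shape `(n_samples, n_genes, n_labels)` in the 3-D case, the loop-and-stack result matches a single batched matrix product; a small sketch of that equivalence:

import torch
import torch.nn.functional as F

n_samples, n_genes, n_labels = 5, 4, 3
bernoulli_params = torch.randn(n_samples, n_genes, n_labels)   # assumed shape
one_hot_label = torch.eye(n_labels)[torch.tensor([0, 2])]      # (batch, n_labels)

# loop and stack, as in reshape_bernoulli
stacked = torch.stack(
    [F.linear(one_hot_label, bernoulli_params[s]) for s in range(n_samples)]
)

# equivalent single batched matmul
batched = torch.matmul(one_hot_label, bernoulli_params.transpose(-1, -2))

print(stacked.shape, torch.allclose(stacked, batched))   # torch.Size([5, 2, 4]) True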
Example #5
    def inference(self, x, batch_index=None, y=None, n_samples=1, force_batch=None):
        x_ = x
        if self.log_variational:
            x_ = torch.log(1 + x_)

        if force_batch is not None:
            batch_index = torch.zeros_like(batch_index).fill_(force_batch)

        # Sampling
        qz_m, qz_v, z = self.z_encoder(x_, y)
        ql_m, ql_v, library = self.l_encoder(x_)

        if n_samples > 1:
            qz_m = qz_m.unsqueeze(0).expand((n_samples, qz_m.size(0), qz_m.size(1)))
            qz_v = qz_v.unsqueeze(0).expand((n_samples, qz_v.size(0), qz_v.size(1)))
            z = Normal(qz_m, qz_v.sqrt()).sample()
            ql_m = ql_m.unsqueeze(0).expand((n_samples, ql_m.size(0), ql_m.size(1)))
            ql_v = ql_v.unsqueeze(0).expand((n_samples, ql_v.size(0), ql_v.size(1)))
            library = Normal(ql_m, ql_v.sqrt()).sample()

        px_scale, px_r, px_rate, px_dropout = self.decoder(self.dispersion, z, library, batch_index, y)
        if self.dispersion == "gene-label":
            px_r = F.linear(one_hot(y, self.n_labels), self.px_r)  # px_r gets transposed - last dimension is nb genes
        elif self.dispersion == "gene-batch":
            px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
        elif self.dispersion == "gene":
            px_r = self.px_r
        px_r = torch.exp(px_r)

        return px_scale, px_r, px_rate, px_dropout, qz_m, qz_v, z, ql_m, ql_v, library
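The `n_samples > 1` branch broadcasts the posterior parameters to a leading sample dimension and then draws one Monte Carlo sample per draw and per cell. A minimal sketch of the resulting shapes (the sizes are made up for illustration):

import torch
from torch.distributions import Normal

n_samples, batch, n_latent = 5, 128, 10
qz_m = torch.zeros(batch, n_latent)   # posterior means, one row per cell
qz_v = torch.ones(batch, n_latent)    # posterior variances

qz_m = qz_m.unsqueeze(0).expand((n_samples, batch, n_latent))
qz_v = qz_v.unsqueeze(0).expand((n_samples, batch, n_latent))
z = Normal(qz_m, qz_v.sqrt()).sample()
print(z.shape)                        # torch.Size([5, 128, 10])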
Example #6
File: vae.py Project: yynst2/scVI
    def inference(self,
                  x,
                  batch_index=None,
                  y=None,
                  n_samples=1,
                  transform_batch=None) -> Dict[str, torch.Tensor]:
        """Helper function used in forward pass
        """
        x_ = x
        if self.log_variational:
            x_ = torch.log(1 + x_)

        # Sampling
        qz_m, qz_v, z = self.z_encoder(x_, y)
        ql_m, ql_v, library = self.l_encoder(x_)

        if n_samples > 1:
            qz_m = qz_m.unsqueeze(0).expand(
                (n_samples, qz_m.size(0), qz_m.size(1)))
            qz_v = qz_v.unsqueeze(0).expand(
                (n_samples, qz_v.size(0), qz_v.size(1)))
            # when z is normal, untran_z == z
            untran_z = Normal(qz_m, qz_v.sqrt()).sample()
            z = self.z_encoder.z_transformation(untran_z)
            ql_m = ql_m.unsqueeze(0).expand(
                (n_samples, ql_m.size(0), ql_m.size(1)))
            ql_v = ql_v.unsqueeze(0).expand(
                (n_samples, ql_v.size(0), ql_v.size(1)))
            library = Normal(ql_m, ql_v.sqrt()).sample()

        if transform_batch is not None:
            dec_batch_index = transform_batch * torch.ones_like(batch_index)
        else:
            dec_batch_index = batch_index

        px_scale, px_r, px_rate, px_dropout = self.decoder(
            self.dispersion, z, library, dec_batch_index, y)
        if self.dispersion == "gene-label":
            px_r = F.linear(
                one_hot(y, self.n_labels),
                self.px_r)  # px_r gets transposed - last dimension is nb genes
        elif self.dispersion == "gene-batch":
            px_r = F.linear(one_hot(dec_batch_index, self.n_batch), self.px_r)
        elif self.dispersion == "gene":
            px_r = self.px_r
        px_r = torch.exp(px_r)

        return dict(
            px_scale=px_scale,
            px_r=px_r,
            px_rate=px_rate,
            px_dropout=px_dropout,
            qz_m=qz_m,
            qz_v=qz_v,
            z=z,
            ql_m=ql_m,
            ql_v=ql_v,
            library=library,
        )
Example #7
    def inference(self,
                  x,
                  batch_index=None,
                  y=None,
                  mode="scRNA",
                  weighting=1):
        x_ = x
        if self.log_variational:
            x_ = torch.log(1 + x_)
        # Sampling
        if mode == "scRNA":
            qz_m, qz_v, z = self.z_encoder(x_)
            library = torch.log(torch.sum(x, dim=1)).view(-1, 1)
            batch_index = torch.zeros_like(library)
        if mode == "smFISH":
            qz_m, qz_v, z = self.z_encoder_fish(x_[:, self.indexes_to_keep])
            library = torch.log(torch.sum(x[:, self.indexes_to_keep],
                                          dim=1)).view(-1, 1)
            batch_index = torch.ones_like(library)
        if self.model_library:
            if mode == "scRNA":
                ql_m, ql_v, library = self.l_encoder(x_)
            elif mode == "smFISH":
                ql_m, ql_v, library = self.l_encoder_fish(
                    x_[:, self.indexes_to_keep])
        else:
            ql_m, ql_v = None, None

        qz_m, qz_v, z = self.z_final_encoder(z)
        px_scale, px_r, px_rate, px_dropout = self.decoder(
            self.dispersion, z, library, batch_index)

        # rescaling the expected frequencies
        if mode == "smFISH":
            if self.model_library:
                px_rate = px_scale[:,
                                   self.indexes_to_keep] * torch.exp(library)
            else:
                px_scale = px_scale[:, self.indexes_to_keep] / torch.sum(
                    px_scale[:, self.indexes_to_keep], dim=1).view(-1, 1)
                px_rate = px_scale * torch.exp(library)

        if self.dispersion == "gene-label":
            px_r = F.linear(
                one_hot(y, self.n_labels),
                self.px_r)  # px_r gets transposed - last dimension is nb genes
        elif self.dispersion == "gene-batch":
            px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
        elif self.dispersion == "gene":
            px_r = self.px_r

        return px_scale, px_r, px_rate, px_dropout, qz_m, qz_v, z, ql_m, ql_v, library
Example #8
    def forward(self,
                x,
                local_l_mean,
                local_l_var,
                batch_index=None,
                y=None):  # same signature as loss
        # Parameters for z latent distribution
        x_ = x
        if self.log_variational:
            x_ = torch.log(1 + x_)

        # Sampling
        qz_m, qz_v, z = self.z_encoder(x_)
        ql_m, ql_v, library = self.l_encoder(x_)

        if self.dispersion == "gene-cell":
            px_scale, self.px_r, px_rate, px_dropout = self.decoder(
                self.dispersion, z, library, batch_index)
        else:  # self.dispersion == "gene", "gene-batch",  "gene-label"
            px_scale, px_rate, px_dropout = self.decoder(
                self.dispersion, z, library, batch_index)

        if self.dispersion == "gene-label":
            px_r = F.linear(
                one_hot(y, self.n_labels),
                self.px_r)  # px_r gets transposed - last dimension is nb genes
        elif self.dispersion == "gene-batch":
            px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
        else:
            px_r = self.px_r

        # Reconstruction Loss
        if self.reconstruction_loss == 'zinb':
            reconst_loss = -log_zinb_positive(x, px_rate, torch.exp(px_r),
                                              px_dropout)
        elif self.reconstruction_loss == 'nb':
            reconst_loss = -log_nb_positive(x, px_rate, torch.exp(px_r))

        # KL Divergence
        mean = torch.zeros_like(qz_m)
        scale = torch.ones_like(qz_v)

        kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)),
                             Normal(mean, scale)).sum(dim=1)
        kl_divergence_l = kl(Normal(ql_m, torch.sqrt(ql_v)),
                             Normal(local_l_mean,
                                    torch.sqrt(local_l_var))).sum(dim=1)
        kl_divergence = kl_divergence_z + kl_divergence_l

        return reconst_loss, kl_divergence
Example #9
    def _reconstruction_loss(self, x, px_rate, px_r, px_dropout, batch_index, y):
        if self.dispersion == "gene-label":
            px_r = F.linear(one_hot(y, self.n_labels), self.px_r)  # px_r gets transposed - last dimension is nb genes
        elif self.dispersion == "gene-batch":
            px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
        elif self.dispersion == "gene":
            px_r = self.px_r

        # Reconstruction Loss
        if self.reconstruction_loss == 'zinb':
            reconst_loss = -log_zinb_positive(x, px_rate, torch.exp(px_r), px_dropout)
        elif self.reconstruction_loss == 'nb':
            reconst_loss = -log_nb_positive(x, px_rate, torch.exp(px_r))
        return reconst_loss
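Here `px_rate` is the negative binomial mean and `torch.exp(px_r)` its inverse dispersion (Example #19's docstring states the inverse-dispersion convention). As a rough, hedged sketch of the kind of term a `log_nb_positive`-style function evaluates under that parameterization, without claiming to match the library implementation exactly:

import torch

def log_nb_positive_sketch(x, mu, theta, eps=1e-8):
    # NB log-likelihood with mean `mu` and inverse dispersion `theta`, summed over genes
    log_theta_mu_eps = torch.log(theta + mu + eps)
    res = (
        theta * (torch.log(theta + eps) - log_theta_mu_eps)
        + x * (torch.log(mu + eps) - log_theta_mu_eps)
        + torch.lgamma(x + theta)
        - torch.lgamma(theta)
        - torch.lgamma(x + 1)
    )
    return res.sum(dim=-1)

x = torch.tensor([[0.0, 3.0, 10.0]])
mu = torch.tensor([[0.5, 2.0, 8.0]])
theta = torch.tensor([[1.0, 1.0, 1.0]])
print(log_nb_positive_sketch(x, mu, theta))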
Example #10
    def forward(self, x: torch.Tensor, *cat_list: int):
        r"""Forward computation on ``x``.

        :param x: tensor of values with shape ``(n_in,)``
        :param cat_list: list of category membership(s) for this sample
        :return: tensor of shape ``(n_out,)``
        :rtype: :py:class:`torch.Tensor`
        """
        one_hot_cat_list = []  # for generality: many entries in this list may be unused
        assert len(self.n_cat_list) <= len(cat_list), "nb. categorical args provided doesn't match init. params."
        for n_cat, cat in zip(self.n_cat_list, cat_list):
            assert not (n_cat and cat is None), "cat not provided while n_cat != 0 in init. params."
            if n_cat > 1:  # n_cat = 1 will be ignored - no additional information
                if cat.size(1) != n_cat:
                    one_hot_cat = one_hot(cat, n_cat)
                else:
                    one_hot_cat = cat  # cat has already been one_hot encoded
                one_hot_cat_list += [one_hot_cat]
        for layers in self.fc_layers:
            for layer in layers:
                if layer is not None:
                    if isinstance(layer, nn.BatchNorm1d):
                        if x.dim() == 3:
                            x = torch.cat([(layer(slice_x)).unsqueeze(0) for slice_x in x], dim=0)
                        else:
                            x = layer(x)
                    else:
                        if isinstance(layer, nn.Linear):
                            if x.dim() == 3:
                                one_hot_cat_list = [o.unsqueeze(0).expand((x.size(0), o.size(0), o.size(1)))
                                                    for o in one_hot_cat_list]
                            x = torch.cat((x, *one_hot_cat_list), dim=-1)
                        x = layer(x)
        return x
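The 3-D branches above handle a leading Monte Carlo sample dimension on `x`: the 2-D one-hot covariates are broadcast before concatenation, and `nn.BatchNorm1d` is applied slice by slice because its 3-D input convention `(N, C, L)` does not match a `(n_samples, batch, features)` tensor. A small sketch of the covariate broadcasting (shapes are illustrative):

import torch

n_samples, batch, n_in, n_cat = 5, 128, 32, 3
x = torch.randn(n_samples, batch, n_in)                         # samples x cells x features
one_hot_cat = torch.eye(n_cat)[torch.randint(n_cat, (batch,))]  # (batch, n_cat)

# broadcast the covariate to the sample dimension, then concatenate on the feature axis
one_hot_cat = one_hot_cat.unsqueeze(0).expand(
    (x.size(0), one_hot_cat.size(0), one_hot_cat.size(1))
)
x = torch.cat((x, one_hot_cat), dim=-1)
print(x.shape)                                                  # torch.Size([5, 128, 35])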
Example #11
File: vae.py Project: jimmayxu/scVI
    def inference(self, x, batch_index=None, y=None, n_samples=1, train_library=True):
        x_ = x
        if self.log_variational:
            x_ = torch.log(1 + x_)

        # Sampling
        qz_m, qz_v, z = self.z_encoder(x_, y)
        ql_m, ql_v, library = self.l_encoder(x_)

        if n_samples > 1:
            assert not self.z_full_cov
            # TODO: Check no issues when full cov
            qz_m = qz_m.unsqueeze(0).expand([n_samples] + list(qz_m.size()))
            qz_v = qz_v.unsqueeze(0).expand([n_samples] + list(qz_v.size()))
            ql_m = ql_m.unsqueeze(0).expand([n_samples] + list(ql_m.size()))
            ql_v = ql_v.unsqueeze(0).expand([n_samples] + list(ql_v.size()))
            z = self.z_encoder.sample(qz_m, qz_v)
            library = self.l_encoder.sample(ql_m, ql_v)

        # library = torch.clamp(library, max=14)
        # if (library >= 14).any():
        #     print('TOTOTATA')

        if not train_library:
            library = 1.0
        px_scale, px_r, px_rate, px_dropout = self.decoder(self.dispersion, z, library, batch_index, y)
        if self.dispersion == "gene-label":
            px_r = F.linear(one_hot(y, self.n_labels), self.px_r)  # px_r gets transposed - last dimension is nb genes
        elif self.dispersion == "gene-batch":
            px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
        elif self.dispersion == "gene":
            px_r = self.px_r
        px_r = torch.exp(px_r)

        return dict(
            px_scale=px_scale,
            px_r=px_r,
            px_rate=px_rate,
            px_dropout=px_dropout,
            qz_m=qz_m,
            qz_v=qz_v,
            z=z,
            ql_m=ql_m,
            ql_v=ql_v,
            library=library
        )
Example #12
    def impute_from_z(self, fixed_batch_indices, fixed_l, sample=False):
        for tensors in self:
            sample_batch, local_l_mean, local_l_var, batch_index, label = tensors
            if not sample:
                if self.model.log_variational:
                    sample_batch = torch.log(1 + sample_batch)
                z = [self.model.z_encoder(sample_batch)[0]]
            else:
                z = [self.model.sample_from_posterior_z(sample_batch)]
            px_scale, px_r, px_rate, px_dropout = self.model.decoder(
                self.model.dispersion, z, fixed_l, fixed_batch_indices)
            if self.model.dispersion == "gene-label":
                px_r = F.linear(
                    one_hot(label, self.model.n_labels), self.model.px_r
                )  # px_r gets transposed - last dimension is nb genes
            elif self.model.dispersion == "gene-batch":
                px_r = F.linear(one_hot(batch_index, self.model.n_batch), self.model.px_r)
            elif self.model.dispersion == "gene":
                px_r = self.model.px_r
            px_r = torch.exp(px_r)

        return px_r
Example #13
    def forward(self, x, *os):
        one_hot_os = []
        for i, o in enumerate(os):
            if o is not None and self.n_cat_list[i]:
                one_hot_o = o
                if o.size(1) != self.n_cat_list[i]:
                    one_hot_o = one_hot(o, self.n_cat_list[i])
                elif o.size(1) == 1 and self.n_cat_list[i] == 1:
                    one_hot_o = o.type(torch.float32)
                one_hot_os += [one_hot_o]
        for layer in self.fc_layers:
            x = layer(torch.cat((x,) + tuple(one_hot_os), 1))
        return x
Example #14
    def loss_discriminator(self,
                           z,
                           batch_index,
                           predict_true_class=True,
                           return_details=True):

        n_classes = self.gene_dataset.n_batches
        cls_logits = torch.nn.LogSoftmax(dim=1)(self.discriminator(z))

        if predict_true_class:
            cls_target = one_hot(batch_index, n_classes)
        else:
            one_hot_batch = one_hot(batch_index, n_classes)
            cls_target = torch.zeros_like(one_hot_batch)
            # keep zero at the true batch; uniform mass over the other classes
            cls_target.masked_scatter_(
                ~one_hot_batch.bool(),
                torch.ones_like(one_hot_batch) / (n_classes - 1))

        l_soft = cls_logits * cls_target
        loss = -l_soft.sum(dim=1).mean()

        return loss
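With `predict_true_class=False`, the target above spreads probability uniformly over every batch except the cell's true one, presumably to confuse the discriminator during adversarial training. A small sketch of that target construction for a single cell (`torch.eye` stands in for the `one_hot` helper):

import torch

n_classes = 4
batch_index = torch.tensor([[2]])                              # the cell's true batch
one_hot_batch = torch.eye(n_classes)[batch_index.squeeze(-1)]  # (1, n_classes)

cls_target = torch.zeros_like(one_hot_batch)
# every position except the true batch receives 1 / (n_classes - 1)
cls_target.masked_scatter_(
    ~one_hot_batch.bool(),
    torch.ones_like(one_hot_batch) / (n_classes - 1),
)
print(cls_target)   # tensor([[0.3333, 0.3333, 0.0000, 0.3333]])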
Example #15
File: modules.py Project: MTreppner/scVI
    def forward(self, dispersion: str, z: torch.Tensor, library: torch.Tensor,
                *cat_list: int):
        # The decoder returns values for the parameters of the ZINB distribution
        p1_ = self.factor_regressor(z)
        if self.n_batches > 1:
            one_hot_cat = one_hot(cat_list[0], self.n_batches)[:, :-1]
            p2_ = self.batch_regressor(one_hot_cat)
            raw_px_scale = p1_ + p2_
        else:
            raw_px_scale = p1_

        px_scale = torch.softmax(raw_px_scale, dim=-1)
        px_dropout = self.px_dropout_decoder(z)
        px_rate = torch.exp(library) * px_scale
        px_r = None

        return px_scale, px_r, px_rate, px_dropout
Example #16
    def forward(self, x, *cat_list):
        one_hot_cat_list = []  # for generality: many entries in this list may be unused
        assert len(self.n_cat_list) <= len(
            cat_list
        ), "nb. categorical args provided doesn't match init. params."
        for n_cat, cat in zip(self.n_cat_list, cat_list):
            assert not (
                n_cat and cat is None
            ), "cat not provided while n_cat != 0 in init. params."
            if n_cat > 1:  # n_cat = 1 will be ignored - no additional information
                if cat.size(1) != n_cat:
                    one_hot_cat = one_hot(cat, n_cat)
                else:
                    one_hot_cat = cat  # cat has already been one_hot encoded
                one_hot_cat_list += [one_hot_cat]
        for layers in self.fc_layers:
            for layer in layers:
                if isinstance(layer, nn.Linear):
                    x = torch.cat((x, *one_hot_cat_list), 1)
                x = layer(x)
        return x
Example #17
    def forward(self, x, *cat_list):
        one_hot_cat_list = []  # for generality: many entries in this list may be unused
        assert len(self.n_cat_list) <= len(cat_list), "nb. categorical args provided doesn't match init. params."
        for n_cat, cat in zip(self.n_cat_list, cat_list):
            assert not (n_cat and cat is None), "cat not provided while n_cat != 0 in init. params."
            if n_cat > 1:  # n_cat = 1 will be ignored - no additional information
                if cat.size(1) != n_cat:
                    one_hot_cat = one_hot(cat, n_cat)
                else:
                    one_hot_cat = cat  # cat has already been one_hot encoded
                one_hot_cat_list += [one_hot_cat]
        for layers in self.fc_layers:
            for layer in layers:
                if isinstance(layer, nn.BatchNorm1d) and x.dim() == 3:
                    x = torch.cat([(layer(slice_x)).unsqueeze(0) for slice_x in x], dim=0)
                else:
                    if isinstance(layer, nn.Linear):
                        if x.dim() == 3:
                            one_hot_cat_list = [o.unsqueeze(0).expand((x.size(0), o.size(0), o.size(1)))
                                                for o in one_hot_cat_list]
                        x = torch.cat((x, *one_hot_cat_list), dim=-1)
                    x = layer(x)
        return x
Example #18
File: vaec.py Project: jstjohn/scVI-dev
    def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None):
        is_labelled = False if y is None else True

        # Prepare for sampling
        xs, ys = (x, y)

        # Enumerate choices of label
        if not is_labelled:
            ys = enumerate_discrete(xs, self.n_labels)
            xs = xs.repeat(self.n_labels, 1)
            if batch_index is not None:
                batch_index = batch_index.repeat(self.n_labels, 1)
            local_l_var = local_l_var.repeat(self.n_labels, 1)
            local_l_mean = local_l_mean.repeat(self.n_labels, 1)
        else:
            ys = one_hot(ys, self.n_labels)

        xs_ = xs
        if self.log_variational:
            xs_ = torch.log(1 + xs_)

        # Sampling
        qz_m, qz_v, z = self.z_encoder(xs_, ys)
        ql_m, ql_v, library = self.l_encoder(xs_)

        if self.dispersion == "gene-cell":
            px_scale, self.px_r, px_rate, px_dropout = self.decoder(
                self.dispersion, z, library, batch_index, y=ys)
        elif self.dispersion == "gene":
            px_scale, px_rate, px_dropout = self.decoder(self.dispersion,
                                                         z,
                                                         library,
                                                         batch_index,
                                                         y=ys)

        # Reconstruction Loss
        if self.reconstruction_loss == 'zinb':
            reconst_loss = -log_zinb_positive(xs, px_rate, torch.exp(
                self.px_r), px_dropout)
        elif self.reconstruction_loss == 'nb':
            reconst_loss = -log_nb_positive(xs, px_rate, torch.exp(self.px_r))

        # KL Divergence
        mean = torch.zeros_like(qz_m)
        scale = torch.ones_like(qz_v)

        kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)),
                             Normal(mean, scale)).sum(dim=1)
        kl_divergence_l = kl(Normal(ql_m, torch.sqrt(ql_v)),
                             Normal(local_l_mean,
                                    torch.sqrt(local_l_var))).sum(dim=1)
        kl_divergence = kl_divergence_z + kl_divergence_l

        if is_labelled:
            return reconst_loss, kl_divergence

        reconst_loss = reconst_loss.view(self.n_labels, -1)
        kl_divergence = kl_divergence.view(self.n_labels, -1)

        x_ = x
        if self.log_variational:
            x_ = torch.log(1 + x)

        probs = self.classifier(x_)
        reconst_loss = (reconst_loss.t() * probs).sum(dim=1)
        kl_divergence = (kl_divergence.t() * probs).sum(dim=1)
        kl_divergence += kl(Multinomial(probs=probs),
                            Multinomial(probs=self.y_prior))

        return reconst_loss, kl_divergence
Example #19
    def inference(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        batch_index: Optional[torch.Tensor] = None,
        label: Optional[torch.Tensor] = None,
        n_samples=1,
        transform_batch: Optional[int] = None,
    ) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]:
        """ Internal helper function to compute necessary inference quantities

         We use the dictionary ``px_`` to contain the parameters of the ZINB/NB for genes.
         The rate refers to the mean of the NB, dropout refers to Bernoulli mixing parameters.
         `scale` refers to the quantity upon which differential expression is performed. For genes,
         this can be viewed as the mean of the underlying gamma distribution.

         We use the dictionary ``py_`` to contain the parameters of the Mixture NB distribution for proteins.
         `rate_fore` refers to foreground mean, while `rate_back` refers to background mean. ``scale`` refers to
         foreground mean adjusted for background probability and scaled to reside in simplex.
         ``back_alpha`` and ``back_beta`` are the posterior parameters for ``rate_back``.  ``fore_scale`` is the scaling
         factor that enforces `rate_fore` > `rate_back`.

         ``px_["r"]`` and ``py_["r"]`` are the inverse dispersion parameters for genes and protein, respectively.
        """
        x_ = x
        y_ = y
        if self.log_variational:
            x_ = torch.log(1 + x_)
            y_ = torch.log(1 + y_)

        # Sampling - Encoder gets concatenated genes + proteins
        qz_m, qz_v, ql_m, ql_v, latent, untran_latent = self.encoder(
            torch.cat((x_, y_), dim=-1), batch_index)
        z = latent["z"]
        library_gene = latent["l"]
        untran_z = untran_latent["z"]
        untran_l = untran_latent["l"]

        if n_samples > 1:
            qz_m = qz_m.unsqueeze(0).expand(
                (n_samples, qz_m.size(0), qz_m.size(1)))
            qz_v = qz_v.unsqueeze(0).expand(
                (n_samples, qz_v.size(0), qz_v.size(1)))
            untran_z = Normal(qz_m, qz_v.sqrt()).sample()
            z = self.encoder.z_transformation(untran_z)
            ql_m = ql_m.unsqueeze(0).expand(
                (n_samples, ql_m.size(0), ql_m.size(1)))
            ql_v = ql_v.unsqueeze(0).expand(
                (n_samples, ql_v.size(0), ql_v.size(1)))
            untran_l = Normal(ql_m, ql_v.sqrt()).sample()
            library_gene = self.encoder.l_transformation(untran_l)

        if self.gene_dispersion == "gene-label":
            # px_r gets transposed - last dimension is nb genes
            px_r = F.linear(one_hot(label, self.n_labels), self.px_r)
        elif self.gene_dispersion == "gene-batch":
            px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
        elif self.gene_dispersion == "gene":
            px_r = self.px_r
        px_r = torch.exp(px_r)

        if self.protein_dispersion == "protein-label":
            # py_r gets transposed - last dimension is n_proteins
            py_r = F.linear(one_hot(label, self.n_labels), self.py_r)
        elif self.protein_dispersion == "protein-batch":
            py_r = F.linear(one_hot(batch_index, self.n_batch), self.py_r)
        elif self.protein_dispersion == "protein":
            py_r = self.py_r
        py_r = torch.exp(py_r)

        # Background regularization
        if self.n_batch > 0:
            py_back_alpha_prior = F.linear(one_hot(batch_index, self.n_batch),
                                           self.background_pro_alpha)
            py_back_beta_prior = F.linear(
                one_hot(batch_index, self.n_batch),
                torch.exp(self.background_pro_log_beta),
            )
        else:
            py_back_alpha_prior = self.background_pro_alpha
            py_back_beta_prior = torch.exp(self.background_pro_log_beta)
        self.back_mean_prior = Normal(py_back_alpha_prior, py_back_beta_prior)

        if transform_batch is not None:
            batch_index = torch.ones_like(batch_index) * transform_batch
        px_, py_, log_pro_back_mean = self.decoder(z, library_gene,
                                                   batch_index, label)
        px_["r"] = px_r
        py_["r"] = py_r

        return dict(
            px_=px_,
            py_=py_,
            qz_m=qz_m,
            qz_v=qz_v,
            z=z,
            untran_z=untran_z,
            ql_m=ql_m,
            ql_v=ql_v,
            library_gene=library_gene,
            untran_l=untran_l,
            log_pro_back_mean=log_pro_back_mean,
        )
Example #20
    def forward(self, x, o, *os):
        if o.size(1) != self.n_cat:
            o = one_hot(o, self.n_cat)
        for layer in self.fc_layers:
            x = layer(torch.cat((x, o), 1))
        return x
Example #21
    def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None):
        is_labelled = False if y is None else True

        xs, ys = (x, y)
        xs_ = torch.log(1 + xs)
        qz1_m, qz1_v, z1_ = self.z_encoder(xs_)
        z1 = z1_
        # Enumerate choices of label
        if not is_labelled:
            ys = enumerate_discrete(xs, self.n_labels)
            xs = xs.repeat(self.n_labels, 1)
            if batch_index is not None:
                batch_index = batch_index.repeat(self.n_labels, 1)
            local_l_var = local_l_var.repeat(self.n_labels, 1)
            local_l_mean = local_l_mean.repeat(self.n_labels, 1)
            qz1_m = qz1_m.repeat(self.n_labels, 1)
            qz1_v = qz1_v.repeat(self.n_labels, 1)
            z1 = z1.repeat(self.n_labels, 1)
        else:
            ys = one_hot(ys, self.n_labels)

        xs_ = torch.log(1 + xs)

        qz2_m, qz2_v, z2 = self.encoder_z2_z1(z1, ys)
        pz1_m, pz1_v = self.decoder_z1_z2(z2, ys)

        # Sampling
        ql_m, ql_v, library = self.l_encoder(xs_)  # let's keep that ind. of y

        px_scale, px_rate, px_dropout = self.decoder(self.dispersion, z1,
                                                     library, batch_index)

        reconst_loss = -log_zinb_positive(xs, px_rate, torch.exp(self.px_r),
                                          px_dropout)

        # KL Divergence
        mean = torch.zeros_like(qz2_m)
        scale = torch.ones_like(qz2_v)

        kl_divergence_z2 = kl(Normal(qz2_m, torch.sqrt(qz2_v)),
                              Normal(mean, scale)).sum(dim=1)
        loss_z1 = (-Normal(pz1_m, torch.sqrt(pz1_v)).log_prob(z1) +
                   Normal(qz1_m, torch.sqrt(qz1_v)).log_prob(z1)).sum(dim=1)
        kl_divergence_l = kl(Normal(ql_m, torch.sqrt(ql_v)),
                             Normal(local_l_mean,
                                    torch.sqrt(local_l_var))).sum(dim=1)
        kl_divergence = kl_divergence_z2 + loss_z1 + kl_divergence_l

        if is_labelled:
            return reconst_loss, kl_divergence

        reconst_loss = reconst_loss.view(self.n_labels, -1)
        kl_divergence = kl_divergence.view(self.n_labels, -1)

        probs = self.classifier(z1_)
        reconst_loss = (reconst_loss.t() * probs).sum(dim=1)
        kl_divergence = (kl_divergence.t() * probs).sum(dim=1)

        kl_divergence += kl(Multinomial(probs=probs),
                            Multinomial(probs=self.y_prior))

        return reconst_loss, kl_divergence
Example #22
    def inference(self, x, batch_index=None, y=None, n_samples=1):
        x_ = x
        if len(x_) != 2:
            raise ValueError(
                "Input training data should be 2 data types (RNA and ATAC), "
                "but input was only {}".format(len(x_))
            )
        x_rna = x_[0]
        x_atac = x_[1]
        if self.log_variational:
            x_rna = torch.log(1 + x_rna)
            x_atac = torch.log(1 + x_atac)


        # Sampling
        qz_rna_m, qz_rna_v, rna_z = self.RNA_encoder(x_rna, y)
        qz_atac_m, qz_atac_v, atac_z = self.ATAC_encoder(x_atac, y)
        qz_m, qz_v, z = self.RNA_ATAC_encoder([x_rna, x_atac], y)
        gamma, mu_c, var_c, pi = self.get_gamma(z)  # , self.n_centroids, c_params)
        index = torch.argmax(gamma,dim=1)
        #mu_c_max = torch.tensor([])
        #var_c_max = torch.tensor([])

        #for index1 in range(len(index)):
        #    mu_c_max = torch.cat((mu_c_max, mu_c[index1,:,index[index1]].float()),1)
        #    var_c_max = torch.cat((var_c_max, var_c[index1,:,index[index1]].float()),1)

        index1 = [i for i in range(len(index))]
        mu_c_max = mu_c[index1,:,index]
        var_c_max = var_c[index1,:,index]
        z_c_max = reparameterize_gaussian(mu_c_max, var_c_max)

        # decoder
        (p_rna_scale, p_rna_r, p_rna_rate, p_rna_dropout,
         p_atac_scale, p_atac_r, p_atac_mean, p_atac_dropout) = self.RNA_ATAC_decoder(z, z_c_max, y)


        if self.dispersion == "gene-label":
            p_rna_r = F.linear(
                one_hot(y, self.n_labels), self.px_r
            )  # px_r gets transposed - last dimension is nb genes
            p_atac_r = F.linear(
                one_hot(y, self.n_labels), self.p_atac_r
            )
        elif self.dispersion == "gene-batch":
            p_rna_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
            p_atac_r = F.linear(one_hot(batch_index, self.n_batch), self.p_atac_r)
        elif self.dispersion == "gene":
            p_rna_r = self.px_r
            p_atac_r = self.p_atac_r

        p_rna_r = torch.exp(p_rna_r)
        p_atac_r = torch.exp(p_atac_r)

        return dict(
            p_rna_scale=p_rna_scale,
            p_rna_r=p_rna_r,
            p_rna_rate=p_rna_rate,
            p_rna_dropout=p_rna_dropout,
            p_atac_scale=p_atac_scale,
            p_atac_r=p_atac_r,
            p_atac_mean=p_atac_mean,
            p_atac_dropout=p_atac_dropout,
            qz_rna_m=qz_rna_m,
            qz_rna_v=qz_rna_v,
            rna_z=rna_z,
            qz_atac_m=qz_atac_m,
            qz_atac_v=qz_atac_v,
            atac_z=atac_z,
            qz_m=qz_m,
            qz_v=qz_v,
            z=z,
            mu_c=mu_c,
            var_c=var_c,
            gamma=gamma,
            pi=pi,
            mu_c_max=mu_c_max,
            var_c_max=var_c_max,
            z_c_max=z_c_max,
        )