Example #1
0
    def traceback(self, x, order):
        """ Generate alignment matrices, one sequence pair at a time.

        Parameters
        ----------
        x : PackedSequence
            Packed sequence object of proteins to align.
        order : np.array
            The origin order of the sequences

        Yields
        ------
        decoded : list of tuple
            State string representing alignment coordinates
        aln : torch.Tensor
            Alignment matrix returned by ``self.nw.decode`` for a single
            pair, decoded with a leading batch axis of size 1
            (presumably 1 x N x M -- confirm against ``decode``).

        Notes
        -----
        Gradient tracking is enabled only while scores are computed and
        decoded; it is not held open across ``yield``. A ``with`` block that
        spans a ``yield`` would keep grad mode forced on in the caller's code
        between iterations, until the generator is exhausted or collected.
        """
        # dim B x N x D
        with torch.enable_grad():
            zx, _, zy, _ = unpack_sequences(self.match_embedding(x), order)
            gx, xlen, gy, ylen = unpack_sequences(self.gap_embedding(x), order)
            match = F.softplus(torch.einsum('bid,bjd->bij', zx, zy))
            gap = F.logsigmoid(torch.einsum('bid,bjd->bij', gx, gy))
        B, _, _ = match.shape
        for b in range(B):
            # Trim padding so each pair is decoded at its true lengths.
            n, m = xlen[b], ylen[b]
            with torch.enable_grad():
                aln = self.nw.decode(match[b, :n, :m].unsqueeze(0),
                                     gap[b, :n, :m].unsqueeze(0))
                # squeeze(0) drops only the batch axis; a bare squeeze()
                # would also collapse N or M for a length-1 sequence.
                decoded = self.nw.traceback(aln.squeeze(0))
            # Yield outside the grad context so the caller's grad mode
            # is not overridden while it consumes this item.
            yield decoded, aln
Example #2
0
    def forward(self, x, order):
        """ Generate alignment matrix.

        Parameters
        ----------
        x : PackedSequence
            Packed sequence object of proteins to align.
        order : np.array
            The origin order of the sequences

        Returns
        -------
        aln : torch.Tensor
            Alignment matrix produced by ``self.nw.decode(theta, A)``.
        theta : torch.Tensor
            Match scores: softplus of the inner products of the match
            embeddings (B x N x M; the leading axis is dropped when B == 1).
        A : torch.Tensor
            Gap scores: logsigmoid of the inner products of the gap
            embeddings (same shape handling as ``theta``).
        """
        with torch.enable_grad():
            match_x, _, match_y, _ = unpack_sequences(
                self.match_embedding(x), order)
            gap_x, _, gap_y, _ = unpack_sequences(
                self.gap_embedding(x), order)

            # Inner products across the latent dimension give pairwise scores.
            match_logits = torch.einsum('bid,bjd->bij', match_x, match_y)
            gap_logits = torch.einsum('bid,bjd->bij', gap_x, gap_y)

            # squeeze(0) only removes the batch axis when it has size 1.
            theta = F.softplus(match_logits).squeeze(0)
            A = F.logsigmoid(gap_logits).squeeze(0)
            aln = self.nw.decode(theta, A)
        return aln, theta, A
Example #3
0
    def score(self, x, order):
        """ Compute alignment scores without tracking gradients.

        Parameters
        ----------
        x : PackedSequence
            Packed sequence object of proteins to align.
        order : np.array
            The origin order of the sequences

        Returns
        -------
        ascore
            Result of calling ``self.nw`` on the match scores and gap
            scores (presumably one alignment score per pair -- confirm).
        """
        with torch.no_grad():
            match_x, _, match_y, _ = unpack_sequences(
                self.match_embedding(x), order)
            gap_x, _, gap_y, _ = unpack_sequences(
                self.gap_embedding(x), order)

            # Obtain theta through an inner product across latent dimensions
            theta = F.softplus(
                torch.einsum('bid,bjd->bij', match_x, match_y))
            A = F.logsigmoid(torch.einsum('bid,bjd->bij', gap_x, gap_y))
            return self.nw(theta, A)
Example #4
0
 def traceback(self, x, order):
     """ Generate alignment matrices, one sequence pair at a time.

     Parameters
     ----------
     x : PackedSequence
         Packed sequence object of proteins to align.
     order : np.array
         The origin order of the sequences

     Yields
     ------
     decoded
         Result of ``self.nw.traceback`` on the pair's alignment matrix.
     aln : torch.Tensor
         Alignment matrix returned by ``self.nw.decode`` for a single pair,
         decoded with a leading batch axis of size 1.

     Notes
     -----
     Gradient tracking is enabled only while scores are computed and
     decoded; it is not held open across ``yield``, so the caller's grad
     mode is respected while it consumes each item.
     """
     # dim B x N x D
     with torch.enable_grad():
         zx, _, zy, _ = unpack_sequences(self.match_embedding(x), order)
         gx, xlen, gy, ylen = unpack_sequences(self.gap_embedding(x), order)
         match = F.softplus(torch.einsum('bid,bjd->bij', zx, zy))
         gap = F.logsigmoid(torch.einsum('bid,bjd->bij', gx, gy))
     B, _, _ = match.shape
     for b in range(B):
         # Trim padding so each pair is decoded at its true lengths.
         n, m = xlen[b], ylen[b]
         with torch.enable_grad():
             aln = self.nw.decode(match[b, :n, :m].unsqueeze(0),
                                  gap[b, :n, :m].unsqueeze(0))
             # squeeze(0) drops only the batch axis; a bare squeeze()
             # would also collapse N or M for a length-1 sequence.
             decoded = self.nw.traceback(aln.squeeze(0))
         yield decoded, aln
Example #5
0
 def test_step(self, batch, batch_idx):
     """ Evaluate one test batch and return alignment statistics.

     Parameters
     ----------
     batch : tuple
         Unpacked as
         ``(genes, others, s, A, P, G, gene_names, other_names)``.
     batch_idx : int
         Index of the batch; forwarded to ``validation_stats``.

     Returns
     -------
     statistics : pd.DataFrame
         Columns ``test_tp``, ``test_fp``, ``test_fn``, ``test_perc_id``,
         ``test_ppv``, ``test_fnr``, ``test_fdr``, plus ``query_name`` and
         ``key_name`` taken from the batch (presumably one row per pair).
     """
     genes, others, s, A, P, G, gene_names, other_names = batch
     seq, order = pack_sequences(genes, others)
     predA, theta, gap = self.aligner(seq, order)
     x, xlen, y, ylen = unpack_sequences(seq, order)
     loss = self.compute_loss(xlen, ylen, predA, A, P, G, theta)
     assert torch.isnan(loss).item() is False
     # Obtain alignment statistics + visualizations
     gen = self.aligner.traceback(seq, order)
     # TODO: compare the traceback and the forward
     statistics = self.validation_stats(x, y, xlen, ylen, gen, s, A, predA,
                                        theta, gap, batch_idx)
     assert len(statistics) > 0, (batch_idx, s)
     # NOTE: the previous version also decoded `genes`/`others` back to
     # strings here, but never used the result; that dead per-sequence
     # device->CPU copy and detokenisation has been removed.
     statistics = pd.DataFrame(statistics,
                               columns=[
                                   'test_tp', 'test_fp', 'test_fn',
                                   'test_perc_id', 'test_ppv', 'test_fnr',
                                   'test_fdr'
                               ])
     statistics['query_name'] = gene_names
     statistics['key_name'] = other_names
     return statistics
Example #6
0
 def test_unpack_sequences(self):
     """ Round-trip two ragged pairs through pack/unpack_sequences.

     Verifies both the recovered per-sequence lengths and the
     zero-padded tensors.
     """
     X = [torch.Tensor([6, 4, 5]), torch.Tensor([1, 4, 5, 7])]
     Y = [torch.Tensor([21, 10, 12, 2, 4, 5]),
          torch.Tensor([1, 4, 11, 13, 14])]
     packed, order = pack_sequences(X, Y)
     resX, xlen, resY, ylen = unpack_sequences(packed, order)

     # Recovered lengths must match the original, unpadded inputs.
     expected_xlen = torch.Tensor([3, 4]).long()
     expected_ylen = torch.Tensor([6, 5]).long()
     tt.assert_allclose(xlen, expected_xlen)
     tt.assert_allclose(ylen, expected_ylen)

     # Both sides come back zero-padded to a common width of 6.
     expX = torch.Tensor([[6, 4, 5, 0, 0, 0],
                          [1, 4, 5, 7, 0, 0]])
     expY = torch.Tensor([[21, 10, 12, 2, 4, 5],
                          [1, 4, 11, 13, 14, 0]])
     tt.assert_allclose(expX, resX)
     tt.assert_allclose(expY, resY)
Example #7
0
 def training_step(self, batch, batch_idx):
     """ Compute the training loss for one batch.

     Parameters
     ----------
     batch : tuple
         Unpacked as ``(genes, others, s, A, P, G)``.
     batch_idx : int
         Index of the batch (not used here).

     Returns
     -------
     dict
         ``{'loss': loss, 'log': {'train_loss': loss, 'lr': current_lr}}``.
     """
     self.aligner.train()
     genes, others, _, A, P, G = batch
     seq, order = pack_sequences(genes, others)
     predA, theta, _ = self.aligner(seq, order)
     _, xlen, _, ylen = unpack_sequences(seq, order)
     loss = self.compute_loss(xlen, ylen, predA, A, P, G, theta)
     assert torch.isnan(loss).item() is False

     # Log the learning rate: take the first scheduler's most recent value
     # when one is configured, otherwise the hyperparameter.
     schedulers = self.trainer.lr_schedulers
     if len(schedulers) < 1:
         current_lr = self.hparams.learning_rate
     else:
         current_lr = schedulers[0]['scheduler'].get_last_lr()[0]

     return {'loss': loss, 'log': {'train_loss': loss, 'lr': current_lr}}
Example #8
0
 def validation_step(self, batch, batch_idx):
     """ Compute validation loss and batch-averaged alignment statistics.

     Parameters
     ----------
     batch : tuple
         Unpacked as ``(genes, others, s, A, P, G)``.
     batch_idx : int
         Index of the batch; forwarded to ``validation_stats``.

     Returns
     -------
     dict
         ``{'validation_loss': loss, 'log': logs}`` where ``logs`` holds
         ``valid_loss`` plus the column means of the ``val_*`` statistics.
     """
     genes, others, s, A, P, G = batch
     seq, order = pack_sequences(genes, others)
     predA, theta, gap = self.aligner(seq, order)
     x, xlen, y, ylen = unpack_sequences(seq, order)
     loss = self.compute_loss(xlen, ylen, predA, A, P, G, theta)
     assert torch.isnan(loss).item() is False

     # Obtain alignment statistics + visualizations
     # TODO: compare the traceback and the forward
     alignments = self.aligner.traceback(seq, order)
     per_pair = self.validation_stats(x, y, xlen, ylen, alignments, s, A,
                                      predA, theta, gap, batch_idx)

     metric_names = ['val_tp', 'val_fp', 'val_fn', 'val_perc_id',
                     'val_ppv', 'val_fnr', 'val_fdr']
     frame = pd.DataFrame(per_pair, columns=metric_names)
     # Average every metric over the rows of this batch.
     averages = frame.mean(axis=0).to_dict()

     logs = {'valid_loss': loss}
     logs.update(averages)
     return {'validation_loss': loss, 'log': logs}