def traceback(self, x, order):
    """Yield a decoded alignment for every sequence pair in the batch.

    Parameters
    ----------
    x : PackedSequence
        Packed sequence object of proteins to align.
    order : np.array
        The origin order of the sequences.

    Yields
    ------
    decoded : list of tuple
        State string representing alignment coordinates, as produced by
        ``self.nw.traceback``.
    aln : torch.Tensor
        Output of ``self.nw.decode`` for a single pair, computed on the
        scores sliced to ``[:xlen[b], :ylen[b]]``.
        NOTE(review): presumably of dim 1 x N x M -- confirm against the
        decoder.

    Notes
    -----
    This is a generator; no work is done until it is iterated.
    """
    einsum_spec = 'bid,bjd->bij'
    # Gradients are switched on explicitly for the embedding/scoring step,
    # whatever grad mode the caller is in.
    with torch.enable_grad():
        zx, _, zy, _ = unpack_sequences(self.match_embedding(x), order)
        gx, xlen, gy, ylen = unpack_sequences(self.gap_embedding(x), order)
        match = F.softplus(torch.einsum(einsum_spec, zx, zy))
        gap = F.logsigmoid(torch.einsum(einsum_spec, gx, gy))
    batch_size, _, _ = match.shape
    for b in range(batch_size):
        rows, cols = xlen[b], ylen[b]
        with torch.enable_grad():
            # Slice to this pair's own lengths, then restore a leading
            # batch axis of size one for the decoder.
            aln = self.nw.decode(match[b, :rows, :cols].unsqueeze(0),
                                 gap[b, :rows, :cols].unsqueeze(0))
            # Drop only the batch axis added above. A bare squeeze() would
            # also remove a sequence axis whenever a sequence has length 1.
            # Assumes decode keeps the leading batch axis -- TODO confirm.
            decoded = self.nw.traceback(aln.squeeze(0))
        # Yield outside the grad context: suspending a generator inside
        # `with torch.enable_grad()` would leave grad mode enabled in the
        # consumer until the generator is resumed.
        yield decoded, aln
def forward(self, x, order):
    """Compute the alignment matrix together with its match and gap scores.

    Parameters
    ----------
    x : PackedSequence
        Packed sequence object of proteins to align.
    order : np.array
        The origin order of the sequences.

    Returns
    -------
    aln : torch.Tensor
        Output of ``self.nw.decode(theta, A)``; described upstream as the
        alignment matrix (dim B x N x M).
    theta : torch.Tensor
        Softplus of the pairwise inner products of the match embeddings.
    A : torch.Tensor
        Log-sigmoid of the pairwise inner products of the gap embeddings.
    """
    pairwise = 'bid,bjd->bij'
    with torch.enable_grad():
        match_x, _, match_y, _ = unpack_sequences(
            self.match_embedding(x), order)
        gap_x, _, gap_y, _ = unpack_sequences(self.gap_embedding(x), order)

        # Inner product across the latent dimension for every position pair.
        match_logits = torch.einsum(pairwise, match_x, match_y)
        gap_logits = torch.einsum(pairwise, gap_x, gap_y)

        # squeeze(0) drops the leading axis only when that axis has size one.
        theta = F.softplus(match_logits).squeeze(0)
        A = F.logsigmoid(gap_logits).squeeze(0)

        aln = self.nw.decode(theta, A)
    return aln, theta, A
def score(self, x, order):
    """Compute the alignment score with gradient tracking disabled.

    Parameters
    ----------
    x : PackedSequence
        Packed sequences to score (assumed to be the same kind of input
        that the embedding layers accept -- confirm against callers).
    order : np.array
        Ordering passed through to ``unpack_sequences``.

    Returns
    -------
    ascore
        Whatever ``self.nw(theta, A)`` returns for the computed match
        (``theta``) and gap (``A``) scores.
    """
    pairwise = 'bid,bjd->bij'
    with torch.no_grad():
        match_x, _, match_y, _ = unpack_sequences(
            self.match_embedding(x), order)
        gap_x, _, gap_y, _ = unpack_sequences(self.gap_embedding(x), order)

        # Inner product across the latent dimension for every position pair.
        match_logits = torch.einsum(pairwise, match_x, match_y)
        gap_logits = torch.einsum(pairwise, gap_x, gap_y)

        theta = F.softplus(match_logits)
        A = F.logsigmoid(gap_logits)
        return self.nw(theta, A)
def traceback(self, x, order):
    """Yield a decoded alignment for every sequence pair in the batch.

    Parameters
    ----------
    x : PackedSequence
        Packed sequences to align (assumed; this is what the embedding
        layers are called with -- confirm against callers).
    order : np.array
        Ordering passed through to ``unpack_sequences``.

    Yields
    ------
    decoded
        Result of ``self.nw.traceback`` for one pair.
    aln : torch.Tensor
        Output of ``self.nw.decode`` for the same pair, computed on the
        scores sliced to ``[:xlen[b], :ylen[b]]``.

    Notes
    -----
    This is a generator; no work is done until it is iterated.
    """
    einsum_spec = 'bid,bjd->bij'
    # Gradients are switched on explicitly for the embedding/scoring step,
    # whatever grad mode the caller is in.
    with torch.enable_grad():
        zx, _, zy, _ = unpack_sequences(self.match_embedding(x), order)
        gx, xlen, gy, ylen = unpack_sequences(self.gap_embedding(x), order)
        match = F.softplus(torch.einsum(einsum_spec, zx, zy))
        gap = F.logsigmoid(torch.einsum(einsum_spec, gx, gy))
    batch_size, _, _ = match.shape
    for b in range(batch_size):
        rows, cols = xlen[b], ylen[b]
        with torch.enable_grad():
            # Slice to this pair's own lengths, then restore a leading
            # batch axis of size one for the decoder.
            aln = self.nw.decode(match[b, :rows, :cols].unsqueeze(0),
                                 gap[b, :rows, :cols].unsqueeze(0))
            # Drop only the batch axis added above. A bare squeeze() would
            # also remove a sequence axis whenever a sequence has length 1.
            # Assumes decode keeps the leading batch axis -- TODO confirm.
            decoded = self.nw.traceback(aln.squeeze(0))
        # Yield outside the grad context: suspending a generator inside
        # `with torch.enable_grad()` would leave grad mode enabled in the
        # consumer until the generator is resumed.
        yield decoded, aln
def test_step(self, batch, batch_idx):
    """Evaluate one test batch and return its alignment statistics.

    Parameters
    ----------
    batch : tuple
        Unpacked as ``(genes, others, s, A, P, G, gene_names,
        other_names)``.
    batch_idx : int
        Index of the batch; forwarded to ``self.validation_stats`` and
        included in the assertion message.

    Returns
    -------
    pd.DataFrame
        One column per ``test_*`` statistic plus ``query_name`` and
        ``key_name`` columns filled from the batch's name fields.
    """
    genes, others, s, A, P, G, gene_names, other_names = batch

    seq, order = pack_sequences(genes, others)
    predA, theta, gap = self.aligner(seq, order)
    x, xlen, y, ylen = unpack_sequences(seq, order)

    loss = self.compute_loss(xlen, ylen, predA, A, P, G, theta)
    assert not torch.isnan(loss).item()

    # Obtain alignment statistics + visualizations
    # TODO: compare the traceback and the forward
    gen = self.aligner.traceback(seq, order)
    statistics = self.validation_stats(
        x, y, xlen, ylen, gen, s, A, predA, theta, gap, batch_idx)
    assert len(statistics) > 0, (batch_idx, s)

    def to_text(tokens):
        # Token tensor -> numpy -> alphabet bytes -> utf-8 string.
        raw = tokens.detach().cpu().numpy()
        return self.tokenizer.alphabet.decode(raw).decode("utf-8")

    # The decoded strings are not part of the returned frame; the decoding
    # is kept so the step still fails on undecodable input.
    genes = [to_text(tokens) for tokens in genes]
    others = [to_text(tokens) for tokens in others]

    stat_columns = [
        'test_tp', 'test_fp', 'test_fn', 'test_perc_id',
        'test_ppv', 'test_fnr', 'test_fdr'
    ]
    statistics = pd.DataFrame(statistics, columns=stat_columns)
    statistics['query_name'] = gene_names
    statistics['key_name'] = other_names
    return statistics
def test_unpack_sequences(self):
    """Round-trip two sequence pairs through pack/unpack.

    Checks that ``unpack_sequences`` reports the original lengths and
    returns both sides zero-padded to the longest sequence (length 6).
    """
    first = [torch.Tensor([6, 4, 5]), torch.Tensor([1, 4, 5, 7])]
    second = [
        torch.Tensor([21, 10, 12, 2, 4, 5]),
        torch.Tensor([1, 4, 11, 13, 14])
    ]

    packed, order = pack_sequences(first, second)
    res_first, len_first, res_second, len_second = unpack_sequences(
        packed, order)

    # Lengths come back per pair, in the original order.
    want_len_first = torch.Tensor([3, 4]).long()
    want_len_second = torch.Tensor([6, 5]).long()
    tt.assert_allclose(len_first, want_len_first)
    tt.assert_allclose(len_second, want_len_second)

    # Shorter sequences are right-padded with zeros.
    want_first = torch.Tensor([[6, 4, 5, 0, 0, 0],
                               [1, 4, 5, 7, 0, 0]])
    want_second = torch.Tensor([[21, 10, 12, 2, 4, 5],
                                [1, 4, 11, 13, 14, 0]])
    tt.assert_allclose(want_first, res_first)
    tt.assert_allclose(want_second, res_second)
def training_step(self, batch, batch_idx):
    """Run one training step and report the loss and learning rate.

    Parameters
    ----------
    batch : tuple
        Unpacked as ``(genes, others, s, A, P, G)``; ``s`` is not used
        here.
    batch_idx : int
        Index of the batch (unused).

    Returns
    -------
    dict
        ``{'loss': loss, 'log': {'train_loss': loss, 'lr': current_lr}}``.
    """
    self.aligner.train()

    genes, others, _, A, P, G = batch
    seq, order = pack_sequences(genes, others)
    predA, theta, _ = self.aligner(seq, order)
    _, xlen, _, ylen = unpack_sequences(seq, order)

    loss = self.compute_loss(xlen, ylen, predA, A, P, G, theta)
    assert not torch.isnan(loss).item()

    # Log the learning rate: take it from the first scheduler when one is
    # configured, otherwise fall back to the configured hyperparameter.
    schedulers = self.trainer.lr_schedulers
    if len(schedulers) >= 1:
        scheduler = schedulers[0]['scheduler']
        current_lr = scheduler.get_last_lr()[0]
    else:
        current_lr = self.hparams.learning_rate

    return {
        'loss': loss,
        'log': {'train_loss': loss, 'lr': current_lr},
    }
def validation_step(self, batch, batch_idx):
    """Evaluate one validation batch and collect averaged statistics.

    Parameters
    ----------
    batch : tuple
        Unpacked as ``(genes, others, s, A, P, G)``.
    batch_idx : int
        Index of the batch; forwarded to ``self.validation_stats``.

    Returns
    -------
    dict
        ``{'validation_loss': loss, 'log': logs}`` where ``logs`` holds
        ``'valid_loss'`` plus the column-wise mean of every ``val_*``
        statistic.
    """
    genes, others, s, A, P, G = batch

    seq, order = pack_sequences(genes, others)
    predA, theta, gap = self.aligner(seq, order)
    x, xlen, y, ylen = unpack_sequences(seq, order)

    loss = self.compute_loss(xlen, ylen, predA, A, P, G, theta)
    assert not torch.isnan(loss).item()

    # Obtain alignment statistics + visualizations
    # TODO: compare the traceback and the forward
    gen = self.aligner.traceback(seq, order)
    per_pair = self.validation_stats(
        x, y, xlen, ylen, gen, s, A, predA, theta, gap, batch_idx)

    stat_columns = [
        'val_tp', 'val_fp', 'val_fn', 'val_perc_id',
        'val_ppv', 'val_fnr', 'val_fdr'
    ]
    frame = pd.DataFrame(per_pair, columns=stat_columns)
    # Collapse the per-pair rows into one mean value per statistic.
    averaged = frame.mean(axis=0).to_dict()

    tensorboard_logs = {'valid_loss': loss, **averaged}
    return {'validation_loss': loss, 'log': tensorboard_logs}