def decode(self, h, mask):  # Viterbi decoding
    # initialize backpointers and viterbi variables in log space
    bptr = LongTensor()
    score = Tensor(BATCH_SIZE, self.num_tags).fill_(-10000.)
    score[:, SOS_IDX] = 0.

    for t in range(h.size(1)):  # recursion through the sequence
        mask_t = mask[:, t].unsqueeze(1)
        score_t = score.unsqueeze(1) + self.trans  # [B, 1, C] -> [B, C, C]
        score_t, bptr_t = score_t.max(2)  # best previous scores and tags
        score_t += h[:, t]  # plus emission scores
        bptr = torch.cat((bptr, bptr_t.unsqueeze(1)), 1)
        score = score_t * mask_t + score * (1 - mask_t)
    score += self.trans[EOS_IDX]
    best_score, best_tag = torch.max(score, 1)

    # back-tracking
    bptr = bptr.tolist()
    best_path = [[i] for i in best_tag.tolist()]
    for b in range(BATCH_SIZE):
        x = best_tag[b]  # best tag
        y = int(mask[b].sum().item())
        for bptr_t in reversed(bptr[b][:y]):
            x = bptr_t[x]
            best_path[b].append(x)
        best_path[b].pop()
        best_path[b].reverse()

    return best_path
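
# Illustration (not part of the original code): a minimal, self-contained sketch of
# the broadcast used in the recursion above, assuming the convention that trans[i, j]
# holds the score of moving from tag j to tag i (which is what taking the max over
# dim 2 implies). B, C, score, trans and emit below are toy placeholders, not the
# model's attributes.
import torch

B, C = 2, 3                                        # batch size, number of tags
score = torch.randn(B, C)                          # Viterbi scores from the previous step
trans = torch.randn(C, C)                          # trans[i, j]: score of tag j -> tag i
emit = torch.randn(B, C)                           # emission scores at the current step

expanded = score.unsqueeze(1) + trans              # [B, 1, C] + [C, C] -> [B, C, C]
best_prev_score, best_prev_tag = expanded.max(2)   # best predecessor for each current tag
new_score = best_prev_score + emit                 # [B, C]

# The same quantity written with explicit loops, to make the broadcasting concrete.
for b in range(B):
    for i in range(C):
        j = best_prev_tag[b, i]
        assert torch.isclose(new_score[b, i], score[b, j] + trans[i, j] + emit[b, i])
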
def decode(self, h, mask):  # Viterbi decoding
    # initialize backpointers and viterbi variables in log space
    backpointers = LongTensor()
    batch_size = h.shape[0]
    delta = Tensor(batch_size, self.num_tags).fill_(-10000.)
    delta[:, START_TAG_IDX] = 0.

    # TODO: is adding stop tag within loop needed at all???
    # pro argument: yes, backpointers needed at every step - to be checked
    for t in range(h.size(1)):  # iterate through the sequence
        # backpointers and viterbi variables at this timestep
        mask_t = mask[:, t].unsqueeze(1)
        # TODO: maybe unsqueeze transition explicitly for 0 dim for clarity
        next_tag_var = delta.unsqueeze(1) + self.transition  # B x 1 x S + S x S
        delta_t, backpointers_t = next_tag_var.max(2)
        backpointers = torch.cat((backpointers, backpointers_t.unsqueeze(1)), 1)
        delta_next = delta_t + h[:, t]  # plus emission scores
        delta = mask_t * delta_next + (1 - mask_t) * delta  # TODO: check correctness

        # for those that end here add score for transitioning to stop tag
        if t + 1 < h.size(1):
            # mask_next = mask[:, t + 1].unsqueeze(1)
            # ending = mask_next.eq(0.).float().expand(batch_size, self.num_tags)
            # delta += ending * self.transition[STOP_TAG_IDX].unsqueeze(0)
            # or
            ending_here = (mask[:, t].eq(1.) * mask[:, t + 1].eq(0.)).view(1, -1).float()
            delta += ending_here.transpose(0, 1).mul(self.transition[STOP_TAG_IDX])  # add outer product of two vecs
            # TODO: check equality of these two again

    # TODO: should we add transition values for getting in stop state only for those that end here?
    # TODO: or to all?
    delta += mask[:, -1].view(1, -1).float().transpose(0, 1).mul(self.transition[STOP_TAG_IDX])
    best_score, best_tag = torch.max(delta, 1)

    # back-tracking
    backpointers = backpointers.tolist()
    best_path = [[i] for i in best_tag.tolist()]
    for idx in range(batch_size):
        prev_best_tag = best_tag[idx]  # best tag id for single instance
        length = int(scalar(mask[idx].sum()))  # length of instance; scalar() is assumed to be a helper defined elsewhere (tensor -> Python number)
        for backpointers_t in reversed(backpointers[idx][:length]):
            prev_best_tag = backpointers_t[prev_best_tag]
            best_path[idx].append(prev_best_tag)
        best_path[idx].pop()  # remove start tag
        best_path[idx].reverse()

    return best_path
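
# Illustration (not part of the original code): a toy check for the
# "check equality of these two again" TODO above. mask, stop_trans and t are
# placeholders, not the model's attributes. Both variants flag sequences that end
# right after step t, but the commented-out variant also fires for sequences that
# ended earlier, so over the whole loop it can add the STOP transition to those
# rows more than once.
import torch

batch_size, num_tags = 3, 4
mask = torch.tensor([[1., 1., 0.],                 # length 2
                     [1., 1., 1.],                 # length 3
                     [1., 0., 0.]])                # length 1
stop_trans = torch.randn(num_tags)                 # stands in for self.transition[STOP_TAG_IDX]
t = 1                                              # a timestep with t + 1 < seq_len

# variant 1 (commented out above): every row whose next position is padding
ending = mask[:, t + 1].unsqueeze(1).eq(0.).float().expand(batch_size, num_tags)
add1 = ending * stop_trans.unsqueeze(0)

# variant 2 (the active code): only rows that are real at t and padded at t + 1
ending_here = (mask[:, t].eq(1.) * mask[:, t + 1].eq(0.)).view(1, -1).float()
add2 = ending_here.transpose(0, 1).mul(stop_trans)

print(torch.allclose(add1, add2))                  # False here: the length-1 row already ended before t
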