Example #1
    def decode(self, h, mask):  # Viterbi decoding
        # initialize backpointers and Viterbi variables in log space
        bptr = []  # per-step backpointers, stacked once after the loop
        score = Tensor(BATCH_SIZE, self.num_tags).fill_(-10000.)
        score[:, SOS_IDX] = 0.

        for t in range(h.size(1)):  # recursion through the sequence
            mask_t = mask[:, t].unsqueeze(1)
            score_t = score.unsqueeze(1) + self.trans  # [B, 1, C] -> [B, C, C]
            score_t, bptr_t = score_t.max(2)  # best previous scores and tags
            score_t += h[:, t]  # plus emission scores
            bptr.append(bptr_t)  # remember the best previous tag for each current tag
            score = score_t * mask_t + score * (1 - mask_t)  # keep the old score at padded positions
        score += self.trans[EOS_IDX]  # transition to the end-of-sequence tag
        best_score, best_tag = torch.max(score, 1)

        # back-tracking
        bptr = torch.stack(bptr, 1).tolist()  # [B, T, C] backpointer table
        best_path = [[i] for i in best_tag.tolist()]
        for b in range(BATCH_SIZE):
            x = int(best_tag[b])  # best tag at the final timestep
            y = int(mask[b].sum().item())  # true length of this sequence
            for bptr_t in reversed(bptr[b][:y]):
                x = bptr_t[x]
                best_path[b].append(x)
            best_path[b].pop()  # drop the SOS tag
            best_path[b].reverse()

        return best_path
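
Both snippets depend on module-level names the page does not show (Tensor, BATCH_SIZE, the tag indices, and the class holding trans). Below is a minimal sketch of the scaffolding Example #1 appears to assume; every definition here is an assumption for illustration, not the actual surrounding module.

import torch

# hypothetical stand-ins for the globals Example #1 assumes
BATCH_SIZE = 2
SOS_IDX, EOS_IDX = 0, 1
Tensor = torch.FloatTensor  # presumably torch.cuda.FloatTensor when CUDA is enabled

class CRF:
    def __init__(self, num_tags):
        self.num_tags = num_tags
        # trans[i][j] = score of moving from tag j to tag i, as implied
        # by the max over dimension 2 in the decode loop
        self.trans = torch.randn(num_tags, num_tags)

    # paste decode(self, h, mask) from Example #1 here

crf = CRF(num_tags=5)
h = torch.randn(BATCH_SIZE, 7, 5)                     # emission scores [B, T, C]
mask = torch.tensor([[1.] * 7, [1.] * 4 + [0.] * 3])  # second sequence has length 4
print(crf.decode(h, mask))  # two tag lists, of lengths 7 and 4

Note that each mask row must be a contiguous run of ones followed by zeros, since decode recovers the sequence length with mask.sum().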
Example #2
    def decode(self, h, mask):  # Viterbi decoding
        # initialize backpointers and Viterbi variables in log space
        backpointers = []  # per-step backpointers, stacked once after the loop
        batch_size = h.shape[0]
        delta = Tensor(batch_size, self.num_tags).fill_(-10000.)
        delta[:, START_TAG_IDX] = 0.

        # the stop-tag transition is added inside the loop for sequences that end
        # early, and once more after the loop for those that run the full length
        for t in range(h.size(1)):  # iterate through the sequence
            # backpointers and Viterbi variables at this timestep
            mask_t = mask[:, t].unsqueeze(1)
            # broadcast transition scores over the batch: [B, 1, S] + [1, S, S]
            next_tag_var = delta.unsqueeze(1) + self.transition.unsqueeze(0)
            delta_t, backpointers_t = next_tag_var.max(2)
            backpointers.append(backpointers_t)
            delta_next = delta_t + h[:, t]  # plus emission scores
            delta = mask_t * delta_next + (1 - mask_t) * delta  # padded positions keep their previous delta
            # for sequences that end at this step, add the score for
            # transitioning to the stop tag
            if t + 1 < h.size(1):
                # alternative formulation (NOT equivalent: it fires again at every
                # padded step after a sequence has ended; see the check below):
                # mask_next = mask[:, t + 1].unsqueeze(1)
                # ending = mask_next.eq(0.).float().expand(batch_size, self.num_tags)
                # delta += ending * self.transition[STOP_TAG_IDX].unsqueeze(0)
                ending_here = (mask[:, t].eq(1.) * mask[:, t + 1].eq(0.)).view(1, -1).float()
                # outer product of the [B] end-here indicator and the [S] stop-transition row
                delta += ending_here.transpose(0, 1).mul(self.transition[STOP_TAG_IDX])

        # sequences that run to the full length receive their stop-tag transition here;
        # shorter ones already got it inside the loop, so each gets it exactly once
        delta += mask[:, -1].view(1, -1).float().transpose(0, 1).mul(self.transition[STOP_TAG_IDX])
        best_score, best_tag = torch.max(delta, 1)

        # back-tracking
        backpointers = torch.stack(backpointers, 1).tolist()  # [B, T, S] backpointer table
        best_path = [[i] for i in best_tag.tolist()]
        for idx in range(batch_size):
            prev_best_tag = int(best_tag[idx])  # best tag at the final timestep
            length = int(mask[idx].sum().item())  # true length of this instance
            for backpointers_t in reversed(backpointers[idx][:length]):
                prev_best_tag = backpointers_t[prev_best_tag]
                best_path[idx].append(prev_best_tag)
            best_path[idx].pop()  # remove start tag
            best_path[idx].reverse()

        return best_path
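
Whether the two "ending" formulations agree can be settled with a small standalone check: the commented-out variant fires at every padded timestep after a sequence ends, while the live one fires exactly once, at the last real timestep. Shapes and names below are illustrative; stop_row stands in for self.transition[STOP_TAG_IDX].

import torch

# one sequence of true length 2, padded to T = 5, with 4 tags
mask = torch.tensor([[1., 1., 0., 0., 0.]])
stop_row = torch.randn(4)  # stands in for self.transition[STOP_TAG_IDX]

added_v1 = torch.zeros(1, 4)
added_v2 = torch.zeros(1, 4)
for t in range(mask.size(1) - 1):
    # variant 1 (commented out above): fires whenever position t + 1 is padding
    ending = mask[:, t + 1].eq(0.).float().unsqueeze(1).expand(1, 4)
    added_v1 += ending * stop_row.unsqueeze(0)
    # variant 2 (the live code): fires only where the sequence actually ends
    ending_here = (mask[:, t].eq(1.) * mask[:, t + 1].eq(0.)).view(1, -1).float()
    added_v2 += ending_here.transpose(0, 1).mul(stop_row)

print(torch.allclose(added_v1, added_v2))  # False: variant 1 added the row three times

So only the second variant, combined with the final addition for full-length sequences, gives every sequence exactly one stop-tag transition.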