示例#1
0
    def decode(self, s_arc, s_rel, mask, tree=False, proj=False):
        r"""
        Predicts a head and a label for every position from arc and label scores.

        Heads are first picked greedily as the per-position argmax over ``s_arc``.
        Only when ``tree`` is ``True`` are the greedy predictions checked with
        ``CoNLL.istree``; sequences that fail the check are re-decoded with
        ``eisner`` (if ``proj``) or ``mst``. Labels are then read off ``s_rel``
        at the predicted heads.

        Args:
            s_arc (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
                Scores of all possible arcs.
            s_rel (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
                Scores of all possible labels on each arc.
            mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
                The mask for covering the unpadded tokens.
            tree (bool):
                If ``True``, ensures to output well-formed trees. Default: ``False``.
            proj (bool):
                If ``True``, ensures to output projective trees. Default: ``False``.
                Only consulted when ``tree`` is ``True``.

        Returns:
            ~torch.LongTensor, ~torch.LongTensor:
                Predicted arcs and labels of shape ``[batch_size, seq_len]``.
        """

        arc_preds = s_arc.argmax(-1)
        # The well-formedness check converts tensors to Python lists and loops
        # over the batch in Python; its result is only used when `tree` is set,
        # so skip it entirely otherwise.
        if tree:
            lens = mask.sum(1)
            # seq[1:i + 1] drops position 0 and keeps `i` heads, `i` being the
            # number of unmasked positions (position 0 is presumably the
            # pseudo-root -- confirm against the caller).
            bad = [
                not CoNLL.istree(seq[1:i + 1], proj)
                for i, seq in zip(lens.tolist(), arc_preds.tolist())
            ]
            if any(bad):
                alg = eisner if proj else mst
                # re-decode only the sequences whose greedy heads failed the check
                arc_preds[bad] = alg(s_arc[bad], mask[bad])
        # best label for every (position, head) pair, then select the predicted head
        rel_preds = s_rel.argmax(-1).gather(
            -1, arc_preds.unsqueeze(-1)).squeeze(-1)

        return arc_preds, rel_preds
示例#2
0
 def enumerate(self, semiring):
     r"""
     Scores every valid tree of every sentence by exhaustive enumeration.

     For a sentence of length ``n`` all ``(n + 1) ** n`` head assignments are
     generated; those rejected by ``CoNLL.istree`` (called with ``True`` as its
     second argument and ``self.multiroot`` as its third) are discarded. The
     score of a kept assignment is ``semiring.prod`` over the entries of
     ``self.scores`` selected by each position ``1..n`` and its head.

     Args:
         semiring:
             Object providing the ``prod`` reduction used to combine arc scores.

     Returns:
         list: One stacked tensor of tree scores per entry of ``self.lens``.
     """
     per_sentence = []
     for idx, n in enumerate(self.lens.tolist()):
         dependents = range(1, n + 1)
         candidates = itertools.product(range(n + 1), repeat=n)
         per_sentence.append([
             semiring.prod(self.scores[idx, dependents, heads], -1)
             for heads in candidates
             if CoNLL.istree(list(heads), True, self.multiroot)
         ])
     return [torch.stack(tree_scores) for tree_scores in per_sentence]
示例#3
0
    def decode(self,
               s_arc,
               s_sib,
               s_rel,
               mask,
               tree=False,
               mbr=True,
               proj=False):
        """
        Predicts a head and a label for every position from arc, sibling and label scores.

        Heads are first picked greedily as the per-position argmax over `s_arc`
        after self-loops have been ruled out. Only when `tree` is True are the
        greedy predictions checked with `CoNLL.istree`. If any sequence fails:
        with `proj` and not `mbr`, the whole batch is re-decoded by `eisner2o`
        using both arc and sibling scores; otherwise only the failing sequences
        are re-decoded with `eisner` (if `proj`) or `mst` on the arc scores.

        Note:
            `s_arc` is modified in place: its diagonal is set to -inf.

        Args:
            s_arc (torch.Tensor): [batch_size, seq_len, seq_len]
                The scores of all possible arcs.
            s_sib (torch.Tensor): [batch_size, seq_len, seq_len, seq_len]
                The scores of all possible dependent-head-sibling triples.
            s_rel (torch.Tensor): [batch_size, seq_len, seq_len, n_labels]
                The scores of all possible labels on each arc.
            mask (torch.BoolTensor): [batch_size, seq_len]
                Mask for covering the unpadded tokens.
            tree (bool):
                If True, ensures to output well-formed trees. Default: False.
            mbr (bool):
                If True, performs MBR decoding. Default: True.
            proj (bool):
                If True, ensures to output projective trees. Default: False.

        Returns:
            arc_preds (torch.Tensor): [batch_size, seq_len]
                The predicted arcs.
            rel_preds (torch.Tensor): [batch_size, seq_len]
                The predicted labels.
        """

        # prevent self-loops (in-place on the caller's tensor)
        s_arc.diagonal(0, 1, 2).fill_(float('-inf'))
        arc_preds = s_arc.argmax(-1)
        # The well-formedness check converts tensors to Python lists and loops
        # over the batch in Python; its result is only used when `tree` is set,
        # so skip it entirely otherwise.
        if tree:
            # per-sequence count of unmasked positions; requires a 2-D mask
            lens = mask.sum(1)
            bad = [
                not CoNLL.istree(seq[1:i + 1], proj)
                for i, seq in zip(lens.tolist(), arc_preds.tolist())
            ]
            if any(bad):
                if proj and not mbr:
                    # second-order decoding replaces the predictions of the whole batch
                    arc_preds = eisner2o((s_arc, s_sib), mask)
                else:
                    alg = eisner if proj else mst
                    # re-decode only the sequences whose greedy heads failed the check
                    arc_preds[bad] = alg(s_arc[bad], mask[bad])
        # best label for every (position, head) pair, then select the predicted head
        rel_preds = s_rel.argmax(-1).gather(
            -1, arc_preds.unsqueeze(-1)).squeeze(-1)

        return arc_preds, rel_preds
示例#4
0
 def enumerate(self, semiring):
     r"""
     Scores every valid tree of every sentence by exhaustive enumeration,
     combining arc scores with sibling scores.

     For a sentence of length ``n`` all ``(n + 1) ** n`` head assignments are
     generated; those rejected by ``CoNLL.istree`` (called with ``True`` as its
     second argument and ``self.multiroot`` as its third) are discarded. For a
     kept assignment, the arc part is ``semiring.prod`` over the entries of
     ``self.scores[0]`` selected by each position ``1..n`` and its head; the
     sibling part is ``semiring.prod`` over the entries of ``self.scores[1]``
     selected by the positive entries of ``CoNLL.get_sibs``. The two parts are
     combined with ``semiring.mul``.

     Args:
         semiring:
             Object providing the ``prod`` and ``mul`` operations used for scoring.

     Returns:
         list: One stacked tensor of tree scores per entry of ``self.lens``.
     """
     results = []
     for idx, n in enumerate(self.lens.tolist()):
         bucket = []
         results.append(bucket)
         end = n + 1
         for heads in itertools.product(range(end), repeat=n):
             if CoNLL.istree(list(heads), True, self.multiroot):
                 # sibling indices for this head assignment; only entries > 0 are scored
                 sib_ids = self.lens.new_tensor(CoNLL.get_sibs(heads))
                 has_sib = sib_ids > 0
                 # restrict both score tensors to this sentence's first n + 1 positions
                 arcs = self.scores[0][idx, :end, :end]
                 sibs = self.scores[1][idx, :end, :end, :end]
                 arc_score = semiring.prod(arcs[range(1, end), heads], -1)
                 picked = sibs[1:][has_sib].gather(-1, sib_ids[has_sib].unsqueeze(-1))
                 sib_score = semiring.prod(picked.squeeze(-1))
                 bucket.append(semiring.mul(arc_score, sib_score))
     return [torch.stack(bucket) for bucket in results]
示例#5
0
    def decode(self,
               s_arc,
               s_sib,
               s_rel,
               mask,
               tree=False,
               mbr=True,
               proj=False):
        r"""
        Predicts a head and a label for every position from arc, sibling and label scores.

        Heads are first picked greedily as the per-position argmax over ``s_arc``.
        Only when ``tree`` is ``True`` are the greedy predictions checked with
        ``CoNLL.istree``; sequences that fail the check are re-decoded via the
        ``argmax`` of ``Dependency2oCRF`` (if ``proj`` and not ``mbr``), of
        ``DependencyCRF`` (if ``proj``) or of ``MatrixTree`` otherwise.

        Args:
            s_arc (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
                Scores of all possible arcs.
            s_sib (~torch.Tensor): ``[batch_size, seq_len, seq_len, seq_len]``.
                Scores of all possible dependent-head-sibling triples.
            s_rel (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
                Scores of all possible labels on each arc.
            mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
                The mask for covering the unpadded tokens.
            tree (bool):
                If ``True``, ensures to output well-formed trees. Default: ``False``.
            mbr (bool):
                If ``True``, performs MBR decoding. Default: ``True``.
            proj (bool):
                If ``True``, ensures to output projective trees. Default: ``False``.

        Returns:
            ~torch.LongTensor, ~torch.LongTensor:
                Predicted arcs and labels of shape ``[batch_size, seq_len]``.
        """

        arc_preds = s_arc.argmax(-1)
        # The well-formedness check converts tensors to Python lists and loops
        # over the batch in Python; its result is only used when `tree` is set,
        # so skip it entirely otherwise.
        if tree:
            lens = mask.sum(1)
            bad = [
                not CoNLL.istree(seq[1:i + 1], proj)
                for i, seq in zip(lens.tolist(), arc_preds.tolist())
            ]
            if any(bad):
                # re-decode only the failing sequences; the structure classes
                # receive per-sequence lengths rather than the mask itself
                if proj and not mbr:
                    arc_preds[bad] = Dependency2oCRF((s_arc[bad], s_sib[bad]),
                                                     mask[bad].sum(-1)).argmax
                else:
                    arc_preds[bad] = (DependencyCRF if proj else MatrixTree)(
                        s_arc[bad], mask[bad].sum(-1)).argmax
        # best label for every (position, head) pair, then select the predicted head
        rel_preds = s_rel.argmax(-1).gather(
            -1, arc_preds.unsqueeze(-1)).squeeze(-1)

        return arc_preds, rel_preds