def inside(self, scores, mask):
        # compute the inside scores of a constituency CRF and return the batch-summed log-partition
        lens = mask[:, 0].sum(-1)
        batch_size, seq_len, _ = scores.shape
        # [seq_len, seq_len, batch_size]
        scores, mask = scores.permute(1, 2, 0), mask.permute(1, 2, 0)
        s = torch.full_like(scores, float('-inf'))

        for w in range(1, seq_len):
            # n denotes the number of spans to iterate,
            # from span (0, w) to span (n-1, n-1+w) given width w
            n = seq_len - w

            if w == 1:
                s.diagonal(w).copy_(scores.diagonal(w))
                continue
            # [n, w, batch_size]
            s_s = stripe(s, n, w - 1, (0, 1)) + stripe(s, n, w - 1, (1, w), 0)
            # [batch_size, n, w]
            s_s = s_s.permute(2, 0, 1)
            if s_s.requires_grad:
                s_s.register_hook(lambda x: x.masked_fill_(torch.isnan(x), 0))
            s_s = s_s.logsumexp(-1)
            s.diagonal(w).copy_(s_s + scores.diagonal(w))

        return s[0].gather(0, lens.unsqueeze(0)).sum()
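All of these listings lean on a `stripe` helper (plus `diagonal_stripe`/`expanded_stripe` analogues in the lexicalized example below) that carves diagonal stripes out of a chart tensor as strided views. Below is a minimal sketch of `stripe`, modeled on the helper of the same name in yzhangcs/parser; treat the exact signature and semantics as assumptions rather than the project's definitive implementation:

import torch

def stripe(x, n, w, offset=(0, 0), dim=1):
    # Return an [n, w, ...] view over a diagonal stripe of chart x, whose
    # first two dims are [seq_len, seq_len]. `offset` shifts the stripe's
    # origin; `dim` selects a horizontal (1) or vertical (0) stripe.
    x = x.contiguous()
    seq_len, numel = x.size(1), x[0, 0].numel()
    stride = list(x.stride())
    stride[0] = (seq_len + 1) * numel                 # step along the diagonal
    stride[1] = (1 if dim == 1 else seq_len) * numel  # step within one span
    return x.as_strided(size=(n, w, *x.shape[2:]), stride=stride,
                        storage_offset=(offset[0] * seq_len + offset[1]) * numel)

Since the result is a view into (a contiguous version of) the chart, summing two stripes reads the chart directly without materializing sub-charts.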
Example #2
def cky(scores, mask):
    """
    The implementation of the Cocke-Kasami-Younger (CKY) algorithm to parse constituency trees.

    References:
        - Yu Zhang, Houquan Zhou and Zhenghua Li (IJCAI'20)
          Fast and Accurate Neural CRF Constituency Parsing
          https://www.ijcai.org/Proceedings/2020/560/

    Args:
        scores (torch.Tensor): [batch_size, seq_len, seq_len]
            The scores of all candidate constituents.
        mask (torch.BoolTensor): [batch_size, seq_len, seq_len]
            Mask to avoid parsing over padding tokens.
            For each square matrix in a batch, the positions except the upper triangular part should be masked out.

    Returns:
        trees (list[list[tuple]]):
            The sequences of factorized predicted bracketed trees traversed in pre-order.
    """

    lens = mask[:, 0].sum(-1)
    scores = scores.permute(1, 2, 0)
    seq_len, seq_len, batch_size = scores.shape
    s = scores.new_zeros(seq_len, seq_len, batch_size)
    p = scores.new_zeros(seq_len, seq_len, batch_size).long()

    for w in range(1, seq_len):
        n = seq_len - w
        starts = p.new_tensor(range(n)).unsqueeze(0)

        if w == 1:
            s.diagonal(w).copy_(scores.diagonal(w))
            continue
        # [n, w, batch_size]
        s_span = stripe(s, n, w - 1, (0, 1)) + stripe(s, n, w - 1, (1, w), 0)
        # [batch_size, n, w]
        s_span = s_span.permute(2, 0, 1)
        # [batch_size, n]
        s_span, p_span = s_span.max(-1)
        s.diagonal(w).copy_(s_span + scores.diagonal(w))
        p.diagonal(w).copy_(p_span + starts + 1)

    def backtrack(p, i, j):
        if j == i + 1:
            return [(i, j)]
        split = p[i][j]
        ltree = backtrack(p, i, split)
        rtree = backtrack(p, split, j)
        return [(i, j)] + ltree + rtree

    p = p.permute(2, 0, 1).tolist()
    trees = [
        backtrack(p[i], 0, length) for i, length in enumerate(lens.tolist())
    ]

    return trees
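For reference, a quick usage sketch of this `cky`, reusing the toy scores and the upper-triangular mask that appear verbatim in the doctests of the later listings:

import torch

scores = torch.tensor([[[ 2.5659,  1.4253, -2.5272,  3.3011],
                        [ 1.3687, -0.5869,  1.0011,  3.3020],
                        [ 1.2297,  0.4862,  1.1975,  2.5387],
                        [-0.0511, -1.2541, -0.7577,  0.2659]]])
# keep only the strictly upper triangular part of each square matrix
mask = torch.ones_like(scores, dtype=torch.bool).triu_(1)
print(cky(scores, mask))
# [[(0, 3), (0, 1), (1, 3), (1, 2), (2, 3)]]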
Example #3
File: tree.py Project: yzhangcs/parser
    def forward(self, semiring):
        batch_size, seq_len = self.scores.shape[:2]
        # [seq_len, seq_len, batch_size, ...], (l->r)
        scores = semiring.convert(self.scores.movedim((1, 2), (0, 1)))
        s = semiring.zeros_like(scores)
        s.diagonal(1).copy_(scores.diagonal(1))

        for w in range(2, seq_len):
            n = seq_len - w
            # [n, batch_size, ...]
            s_s = semiring.dot(stripe(s, n, w - 1, (0, 1)),
                               stripe(s, n, w - 1, (1, w), False), 1)
            s.diagonal(w).copy_(
                semiring.mul(s_s,
                             scores.diagonal(w).movedim(-1, 0)).movedim(0, -1))
        return semiring.unconvert(s)[0][self.lens, range(batch_size)]
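The semiring-parameterized listings (this one and Examples #4, #6 and #10) touch the semiring only through a small interface: `convert`/`unconvert`, `zeros_like`, `zero_`/`one_`, `mul`/`times`, `sum`/`dot`, and the `zero` constant. Here is a minimal log-space sketch of that interface; the project's actual semirings may treat `convert` differently (e.g., reshaping extra dims), so treat this as an assumption:

import torch

class LogSemiring:
    # log-space semiring: plus is logsumexp, times is addition
    zero, one = float('-inf'), 0.

    @classmethod
    def convert(cls, x):  # identity here; real semirings may reshape
        return x

    @classmethod
    def unconvert(cls, x):
        return x

    @classmethod
    def zeros_like(cls, x):
        return torch.full_like(x, cls.zero)

    @classmethod
    def zero_(cls, x):
        return x.fill_(cls.zero)

    @classmethod
    def one_(cls, x):
        return x.fill_(cls.one)

    @classmethod
    def mul(cls, x, y):
        return x + y

    @classmethod
    def times(cls, *xs):
        # chained semiring product (elementwise sum in log space)
        out = xs[0]
        for x in xs[1:]:
            out = cls.mul(out, x)
        return out

    @classmethod
    def sum(cls, x, dim=-1):
        return x.logsumexp(dim)

    @classmethod
    def dot(cls, x, y, dim=-1):
        # semiring inner product: sum over dim of the products
        return cls.sum(cls.mul(x, y), dim)

With a max-plus semiring in place of logsumexp, the same `forward` computes Viterbi scores instead of the log-partition.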
Example #4
File: tree.py Project: yzhangcs/parser
    def forward(self, semiring):
        s_dep, s_con = self.scores[0], self.scores[1]
        batch_size, seq_len, *_ = s_con.shape
        # [seq_len, seq_len, batch_size, ...], (m<-h)
        s_dep = semiring.convert(s_dep.movedim(0, 2))
        s_root, s_dep = s_dep[1:, 0], s_dep[1:, 1:]
        # [seq_len, seq_len, batch_size, ...], (l->r)
        s_con = semiring.convert(s_con.movedim((1, 2), (0, 1)))
        # [seq_len, seq_len, seq_len, batch_size, ...], (i, j, h)
        s_span = semiring.zero_(
            s_con.new_empty(seq_len, seq_len, seq_len - 1, *s_con.shape[2:]))
        # [seq_len, seq_len, seq_len, batch_size, ...], (i, j<-h)
        s_hook = semiring.zero_(
            s_con.new_empty(seq_len, seq_len, seq_len - 1, *s_con.shape[2:]))
        diagonal_stripe(s_span, 1).copy_(
            s_con.diagonal(1).movedim(-1, 0).unsqueeze(1))
        s_hook.diagonal(1).copy_(
            semiring.mul(s_dep,
                         s_con.diagonal(1).movedim(-1, 0).unsqueeze(1)).movedim(0, -1))

        for w in range(2, seq_len):
            n = seq_len - w
            # COMPLETE-L: s_span_l(i, j, h) = <s_span(i, k, h), s_hook(h->k, j)>, i < k < j
            # [n, w, batch_size, ...]
            s_l = stripe(
                semiring.dot(stripe(s_span, n, w - 1, (0, 1)),
                             stripe(s_hook, n, w - 1, (1, w), False), 1), n, w)
            # COMPLETE-R: s_span_r(i, j, h) = <s_hook(i, k<-h), s_span(k, j, h)>, i < k < j
            # [n, w, batch_size, ...]
            s_r = stripe(
                semiring.dot(stripe(s_hook, n, w - 1, (0, 1)),
                             stripe(s_span, n, w - 1, (1, w), False), 1), n, w)
            # COMPLETE: s_span(i, j, h) = s_span_l(i, j, h) + s_span_r(i, j, h) + s(i, j)
            # [n, w, batch_size, ...]
            s = semiring.mul(semiring.sum(torch.stack((s_l, s_r)), 0),
                             s_con.diagonal(w).movedim(-1, 0).unsqueeze(1))
            diagonal_stripe(s_span, w).copy_(s)

            if w == seq_len - 1:
                continue
            # ATTACH: s_hook(h->i, j) = <s(h->m), s_span(i, j, m)>, i <= m < j
            # [n, seq_len, batch_size, ...]
            s = semiring.dot(expanded_stripe(s_dep, n, w),
                             diagonal_stripe(s_span, w).unsqueeze(2), 1)
            s_hook.diagonal(w).copy_(s.movedim(0, -1))
        return semiring.unconvert(
            semiring.dot(
                s_span[0][self.lens, :, range(batch_size)].transpose(0, 1),
                s_root, 0))
Example #5

    def inside(self, scores, mask, cands=None):
        # the end position of each sentence in a batch
        lens = mask.sum(1)
        batch_size, seq_len, _ = scores.shape
        # [seq_len, seq_len, batch_size]
        scores = scores.permute(2, 1, 0)
        s_i = torch.full_like(scores, float('-inf'))
        s_c = torch.full_like(scores, float('-inf'))
        s_c.diagonal().fill_(0)

        # set the scores of arcs excluded by cands to -inf
        if cands is not None:
            mask = mask.index_fill(1, lens.new_tensor(0), 1)
            mask = (mask.unsqueeze(1) & mask.unsqueeze(-1)).permute(2, 1, 0)
            cands = cands.unsqueeze(-1).index_fill(1, lens.new_tensor(0), -1)
            cands = cands.eq(lens.new_tensor(range(seq_len))) | cands.lt(0)
            cands = cands.permute(2, 1, 0) & mask
            scores = scores.masked_fill(~cands, float('-inf'))

        for w in range(1, seq_len):
            # n denotes the number of spans to iterate,
            # from span (0, w) to span (n-1, n-1+w) given width w
            n = seq_len - w

            # ilr = C(i->r) + C(j->r+1)
            # [n, w, batch_size]
            ilr = stripe(s_c, n, w) + stripe(s_c, n, w, (w, 1))
            if ilr.requires_grad:
                ilr.register_hook(lambda x: x.masked_fill_(torch.isnan(x), 0))
            il = ir = ilr.permute(2, 0, 1).logsumexp(-1)
            # I(j->i) = logsumexp(C(i->r) + C(j->r+1)) + s(j->i), i <= r < j
            # fill the w-th diagonal of the lower triangular part of s_i
            # with I(j->i) of n spans
            s_i.diagonal(-w).copy_(il + scores.diagonal(-w))
            # I(i->j) = logsumexp(C(i->r) + C(j->r+1)) + s(i->j), i <= r < j
            # fill the w-th diagonal of the upper triangular part of s_i
            # with I(i->j) of n spans
            s_i.diagonal(w).copy_(ir + scores.diagonal(w))

            # C(j->i) = logsumexp(C(r->i) + I(j->r)), i <= r < j
            cl = stripe(s_c, n, w, (0, 0), 0) + stripe(s_i, n, w, (w, 0))
            if cl.requires_grad:
                cl.register_hook(lambda x: x.masked_fill_(torch.isnan(x), 0))
            s_c.diagonal(-w).copy_(cl.permute(2, 0, 1).logsumexp(-1))
            # C(i->j) = logsumexp(I(i->r) + C(r->j)), i < r <= j
            cr = stripe(s_i, n, w, (0, 1)) + stripe(s_c, n, w, (1, w), 0)
            if cr.requires_grad:
                cr.register_hook(lambda x: x.masked_fill_(torch.isnan(x), 0))
            s_c.diagonal(w).copy_(cr.permute(2, 0, 1).logsumexp(-1))
            # forbid multiple words from modifying the root (single-root constraint)
            s_c[0, w][lens.ne(w)] = float('-inf')

        return s_c[0].gather(0, lens.unsqueeze(0)).sum()
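Since this `inside` returns the log-partition summed over the batch, arc marginals fall out of autograd. A hypothetical usage sketch; `crf` stands for any object exposing the method above and is not part of the original listing:

import torch

batch_size, seq_len = 2, 6
scores = torch.randn(batch_size, seq_len, seq_len, requires_grad=True)
mask = torch.ones(batch_size, seq_len, dtype=torch.bool)
mask[:, 0] = False  # position 0 is the pseudo root

logZ = crf.inside(scores, mask)  # assumed: crf carries the inside method above
logZ.backward()
# scores.grad now holds the arc marginals, in the same layout as scores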
Example #6
File: tree.py Project: yzhangcs/parser
    def forward(self, semiring):
        s_arc = self.scores
        batch_size, seq_len = s_arc.shape[:2]
        # [seq_len, seq_len, batch_size, ...], (h->m)
        s_arc = semiring.convert(s_arc.movedim((1, 2), (1, 0)))
        s_i = semiring.zeros_like(s_arc)
        s_c = semiring.zeros_like(s_arc)
        semiring.one_(s_c.diagonal().movedim(-1, 1))

        for w in range(1, seq_len):
            n = seq_len - w

            # [n, batch_size, ...]
            il = ir = semiring.dot(stripe(s_c, n, w),
                                   stripe(s_c, n, w, (w, 1)), 1)
            # INCOMPLETE-L: I(j->i) = <C(i->r), C(j->r+1)> * s(j->i), i <= r < j
            # fill the w-th diagonal of the lower triangular part of s_i with I(j->i) of n spans
            s_i.diagonal(-w).copy_(
                semiring.mul(il,
                             s_arc.diagonal(-w).movedim(-1, 0)).movedim(0, -1))
            # INCOMPLETE-R: I(i->j) = <C(i->r), C(j->r+1)> * s(i->j), i <= r < j
            # fill the w-th diagonal of the upper triangular part of s_i with I(i->j) of n spans
            s_i.diagonal(w).copy_(
                semiring.mul(ir,
                             s_arc.diagonal(w).movedim(-1, 0)).movedim(0, -1))

            # [n, batch_size, ...]
            # COMPLETE-L: C(j->i) = <C(r->i), I(j->r)>, i <= r < j
            cl = semiring.dot(stripe(s_c, n, w, (0, 0), 0),
                              stripe(s_i, n, w, (w, 0)), 1)
            s_c.diagonal(-w).copy_(cl.movedim(0, -1))
            # COMPLETE-R: C(i->j) = <I(i->r), C(r->j)>, i < r <= j
            cr = semiring.dot(stripe(s_i, n, w, (0, 1)),
                              stripe(s_c, n, w, (1, w), 0), 1)
            s_c.diagonal(w).copy_(cr.movedim(0, -1))
            if not self.multiroot:
                s_c[0, w][self.lens.ne(w)] = semiring.zero
        return semiring.unconvert(s_c)[0][self.lens, range(batch_size)]
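With the LogSemiring sketch from earlier, this `forward` can be exercised through a bare holder object carrying the attributes it reads (`scores`, `lens`, `multiroot`). All names below are hypothetical scaffolding, not the project's API:

import torch

batch_size, seq_len = 2, 6
dist = type('Holder', (), {})()  # stand-in for the distribution instance
dist.scores = torch.randn(batch_size, seq_len, seq_len)  # first-order arc scores
dist.lens = torch.full((batch_size,), seq_len - 1)       # sentence lengths
dist.multiroot = False
# calling the method above (as a plain function) under the log semiring
# yields the per-sentence log-partition, shape [batch_size]
logZ = forward(dist, LogSemiring)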
Example #7
def cky(scores, mask):
    r"""
    The implementation of the `Cocke-Kasami-Younger`_ (CKY) algorithm to parse constituency trees.

    References:
        - Yu Zhang, Houquan Zhou and Zhenghua Li. 2020.
          `Fast and Accurate Neural CRF Constituency Parsing`_.

    Args:
        scores (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
            Scores of all candidate constituents.
        mask (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``.
            The mask to avoid parsing over padding tokens.
            For each square matrix in a batch, the positions except the upper triangular part should be masked out.

    Returns:
        Sequences of factorized predicted bracketed trees that are traversed in pre-order.

    Examples:
        >>> scores = torch.tensor([[[ 2.5659,  1.4253, -2.5272,  3.3011],
                                    [ 1.3687, -0.5869,  1.0011,  3.3020],
                                    [ 1.2297,  0.4862,  1.1975,  2.5387],
                                    [-0.0511, -1.2541, -0.7577,  0.2659]]])
        >>> mask = torch.tensor([[[False,  True,  True,  True],
                                  [False, False,  True,  True],
                                  [False, False, False,  True],
                                  [False, False, False, False]]])
        >>> cky(scores, mask)
        [[(0, 3), (0, 1), (1, 3), (1, 2), (2, 3)]]

    .. _Cocke-Kasami-Younger:
        https://en.wikipedia.org/wiki/CYK_algorithm
    .. _Fast and Accurate Neural CRF Constituency Parsing:
        https://www.ijcai.org/Proceedings/2020/560/
    """

    lens = mask[:, 0].sum(-1)
    scores = scores.permute(1, 2, 0)
    seq_len, seq_len, batch_size = scores.shape
    s = scores.new_zeros(seq_len, seq_len, batch_size)
    p = scores.new_zeros(seq_len, seq_len, batch_size).long()

    for w in range(1, seq_len):
        n = seq_len - w
        starts = p.new_tensor(range(n)).unsqueeze(0)

        if w == 1:
            s.diagonal(w).copy_(scores.diagonal(w))
            continue
        # [n, w, batch_size]
        s_span = stripe(s, n, w - 1, (0, 1)) + stripe(s, n, w - 1, (1, w), 0)
        # [batch_size, n, w]
        s_span = s_span.permute(2, 0, 1)
        # [batch_size, n]
        s_span, p_span = s_span.max(-1)
        s.diagonal(w).copy_(s_span + scores.diagonal(w))
        p.diagonal(w).copy_(p_span + starts + 1)

    def backtrack(p, i, j):
        if j == i + 1:
            return [(i, j)]
        split = p[i][j]
        ltree = backtrack(p, i, split)
        rtree = backtrack(p, split, j)
        return [(i, j)] + ltree + rtree

    p = p.permute(2, 0, 1).tolist()
    trees = [
        backtrack(p[i], 0, length) for i, length in enumerate(lens.tolist())
    ]

    return trees
Example #8
def eisner2o(scores, mask):
    r"""
    Second-order Eisner algorithm for projective decoding.
    This is an extension of the first-order one that further incorporates sibling scores into tree scoring.

    References:
        - Ryan McDonald and Fernando Pereira. 2006.
          `Online Learning of Approximate Dependency Parsing Algorithms`_.

    Args:
        scores (tuple[~torch.Tensor, ~torch.Tensor]):
            A tuple of two tensors representing the first-order and second-order scores respectively.
            The first (``[batch_size, seq_len, seq_len]``) holds scores of all dependent-head pairs.
            The second (``[batch_size, seq_len, seq_len, seq_len]``) holds scores of all dependent-head-sibling triples.
        mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
            The mask to avoid parsing over padding tokens.
            The first column serving as pseudo words for roots should be ``False``.

    Returns:
        ~torch.Tensor:
            A tensor with shape ``[batch_size, seq_len]`` for the resulting projective parse trees.

    Examples:
        >>> s_arc = torch.tensor([[[ -2.8092,  -7.9104,  -0.9414,  -5.4360],
                                   [-10.3494,  -7.9298,  -3.6929,  -7.3985],
                                   [  1.1815,  -3.8291,   2.3166,  -2.7183],
                                   [ -3.9776,  -3.9063,  -1.6762,  -3.1861]]])
        >>> s_sib = torch.tensor([[[[ 0.4719,  0.4154,  1.1333,  0.6946],
                                    [ 1.1252,  1.3043,  2.1128,  1.4621],
                                    [ 0.5974,  0.5635,  1.0115,  0.7550],
                                    [ 1.1174,  1.3794,  2.2567,  1.4043]],
                                   [[-2.1480, -4.1830, -2.5519, -1.8020],
                                    [-1.2496, -1.7859, -0.0665, -0.4938],
                                    [-2.6171, -4.0142, -2.9428, -2.2121],
                                    [-0.5166, -1.0925,  0.5190,  0.1371]],
                                   [[ 0.5827, -1.2499, -0.0648, -0.0497],
                                    [ 1.4695,  0.3522,  1.5614,  1.0236],
                                    [ 0.4647, -0.7996, -0.3801,  0.0046],
                                    [ 1.5611,  0.3875,  1.8285,  1.0766]],
                                   [[-1.3053, -2.9423, -1.5779, -1.2142],
                                    [-0.1908, -0.9699,  0.3085,  0.1061],
                                    [-1.6783, -2.8199, -1.8853, -1.5653],
                                    [ 0.3629, -0.3488,  0.9011,  0.5674]]]])
        >>> mask = torch.tensor([[False,  True,  True,  True]])
        >>> eisner2o((s_arc, s_sib), mask)
        tensor([[0, 2, 0, 2]])

    .. _Online Learning of Approximate Dependency Parsing Algorithms:
        https://www.aclweb.org/anthology/E06-1011/
    """

    # the end position of each sentence in a batch
    lens = mask.sum(1)
    s_arc, s_sib = scores
    batch_size, seq_len, _ = s_arc.shape
    # [seq_len, seq_len, batch_size]
    s_arc = s_arc.permute(2, 1, 0)
    # [seq_len, seq_len, seq_len, batch_size]
    s_sib = s_sib.permute(2, 1, 3, 0)
    s_i = torch.full_like(s_arc, float('-inf'))
    s_s = torch.full_like(s_arc, float('-inf'))
    s_c = torch.full_like(s_arc, float('-inf'))
    p_i = s_arc.new_zeros(seq_len, seq_len, batch_size).long()
    p_s = s_arc.new_zeros(seq_len, seq_len, batch_size).long()
    p_c = s_arc.new_zeros(seq_len, seq_len, batch_size).long()
    s_c.diagonal().fill_(0)

    for w in range(1, seq_len):
        # n denotes the number of spans to iterate,
        # from span (0, w) to span (n-1, n-1+w) given width w
        n = seq_len - w
        starts = p_i.new_tensor(range(n)).unsqueeze(0)
        # I(j->i) = max(I(j->r) + S(j->r, i), i < r < j |
        #               C(j->j) + C(i->j-1))
        #           + s(j->i)
        # [n, w, batch_size]
        il = stripe(s_i, n, w, (w, 1)) + stripe(s_s, n, w, (1, 0), 0)
        il += stripe(s_sib[range(w, n + w), range(n)], n, w, (0, 1))
        # [n, 1, batch_size]
        il0 = stripe(s_c, n, 1, (w, w)) + stripe(s_c, n, 1, (0, w - 1))
        # il0[0] is set to 0 since the scores of complete spans starting from 0 are always -inf
        il[:, -1] = il0.index_fill_(0, lens.new_tensor(0), 0).squeeze(1)
        il_span, il_path = il.permute(2, 0, 1).max(-1)
        s_i.diagonal(-w).copy_(il_span + s_arc.diagonal(-w))
        p_i.diagonal(-w).copy_(il_path + starts + 1)
        # I(i->j) = max(I(i->r) + S(i->r, j), i < r < j |
        #               C(i->i) + C(j->i+1))
        #           + s(i->j)
        # [n, w, batch_size]
        ir = stripe(s_i, n, w) + stripe(s_s, n, w, (0, w), 0)
        ir += stripe(s_sib[range(n), range(w, n + w)], n, w)
        ir[0] = float('-inf')
        # [n, 1, batch_size]
        ir0 = stripe(s_c, n, 1) + stripe(s_c, n, 1, (w, 1))
        ir[:, 0] = ir0.squeeze(1)
        ir_span, ir_path = ir.permute(2, 0, 1).max(-1)
        s_i.diagonal(w).copy_(ir_span + s_arc.diagonal(w))
        p_i.diagonal(w).copy_(ir_path + starts)

        # [n, w, batch_size]
        slr = stripe(s_c, n, w) + stripe(s_c, n, w, (w, 1))
        slr_span, slr_path = slr.permute(2, 0, 1).max(-1)
        # S(j, i) = max(C(i->r) + C(j->r+1)), i <= r < j
        s_s.diagonal(-w).copy_(slr_span)
        p_s.diagonal(-w).copy_(slr_path + starts)
        # S(i, j) = max(C(i->r) + C(j->r+1)), i <= r < j
        s_s.diagonal(w).copy_(slr_span)
        p_s.diagonal(w).copy_(slr_path + starts)

        # C(j->i) = max(C(r->i) + I(j->r)), i <= r < j
        cl = stripe(s_c, n, w, (0, 0), 0) + stripe(s_i, n, w, (w, 0))
        cl_span, cl_path = cl.permute(2, 0, 1).max(-1)
        s_c.diagonal(-w).copy_(cl_span)
        p_c.diagonal(-w).copy_(cl_path + starts)
        # C(i->j) = max(I(i->r) + C(r->j)), i < r <= j
        cr = stripe(s_i, n, w, (0, 1)) + stripe(s_c, n, w, (1, w), 0)
        cr_span, cr_path = cr.permute(2, 0, 1).max(-1)
        s_c.diagonal(w).copy_(cr_span)
        # forbid multiple words from modifying the root (single-root constraint)
        s_c[0, w][lens.ne(w)] = float('-inf')
        p_c.diagonal(w).copy_(cr_path + starts + 1)

    def backtrack(p_i, p_s, p_c, heads, i, j, flag):
        if i == j:
            return
        if flag == 'c':
            r = p_c[i, j]
            backtrack(p_i, p_s, p_c, heads, i, r, 'i')
            backtrack(p_i, p_s, p_c, heads, r, j, 'c')
        elif flag == 's':
            r = p_s[i, j]
            i, j = sorted((i, j))
            backtrack(p_i, p_s, p_c, heads, i, r, 'c')
            backtrack(p_i, p_s, p_c, heads, j, r + 1, 'c')
        elif flag == 'i':
            r, heads[j] = p_i[i, j], i
            if r == i:
                r = i + 1 if i < j else i - 1
                backtrack(p_i, p_s, p_c, heads, j, r, 'c')
            else:
                backtrack(p_i, p_s, p_c, heads, i, r, 'i')
                backtrack(p_i, p_s, p_c, heads, r, j, 's')

    preds = []
    p_i = p_i.permute(2, 0, 1).cpu()
    p_s = p_s.permute(2, 0, 1).cpu()
    p_c = p_c.permute(2, 0, 1).cpu()
    for i, length in enumerate(lens.tolist()):
        heads = p_c.new_zeros(length + 1, dtype=torch.long)
        backtrack(p_i[i], p_s[i], p_c[i], heads, 0, length, 'c')
        preds.append(heads.to(mask.device))

    return pad(preds, total_length=seq_len).to(mask.device)
Example #9
def eisner(scores, mask):
    r"""
    First-order Eisner algorithm for projective decoding.

    References:
        - Ryan McDonald, Koby Crammer and Fernando Pereira. 2005.
          `Online Large-Margin Training of Dependency Parsers`_.

    Args:
        scores (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
            Scores of all dependent-head pairs.
        mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
            The mask to avoid parsing over padding tokens.
            The first column serving as pseudo words for roots should be ``False``.

    Returns:
        ~torch.Tensor:
            A tensor with shape ``[batch_size, seq_len]`` for the resulting projective parse trees.

    Examples:
        >>> scores = torch.tensor([[[-13.5026, -18.3700, -13.0033, -16.6809],
                                    [-36.5235, -28.6344, -28.4696, -31.6750],
                                    [ -2.9084,  -7.4825,  -1.4861,  -6.8709],
                                    [-29.4880, -27.6905, -26.1498, -27.0233]]])
        >>> mask = torch.tensor([[False,  True,  True,  True]])
        >>> eisner(scores, mask)
        tensor([[0, 2, 0, 2]])

    .. _Online Large-Margin Training of Dependency Parsers:
        https://www.aclweb.org/anthology/P05-1012/
    """

    lens = mask.sum(1)
    batch_size, seq_len, _ = scores.shape
    scores = scores.permute(2, 1, 0)
    s_i = torch.full_like(scores, float('-inf'))
    s_c = torch.full_like(scores, float('-inf'))
    p_i = scores.new_zeros(seq_len, seq_len, batch_size).long()
    p_c = scores.new_zeros(seq_len, seq_len, batch_size).long()
    s_c.diagonal().fill_(0)

    for w in range(1, seq_len):
        n = seq_len - w
        starts = p_i.new_tensor(range(n)).unsqueeze(0)
        # ilr = C(i->r) + C(j->r+1)
        ilr = stripe(s_c, n, w) + stripe(s_c, n, w, (w, 1))
        # [batch_size, n, w]
        il = ir = ilr.permute(2, 0, 1)
        # I(j->i) = max(C(i->r) + C(j->r+1) + s(j->i)), i <= r < j
        il_span, il_path = il.max(-1)
        s_i.diagonal(-w).copy_(il_span + scores.diagonal(-w))
        p_i.diagonal(-w).copy_(il_path + starts)
        # I(i->j) = max(C(i->r) + C(j->r+1) + s(i->j)), i <= r < j
        ir_span, ir_path = ir.max(-1)
        s_i.diagonal(w).copy_(ir_span + scores.diagonal(w))
        p_i.diagonal(w).copy_(ir_path + starts)

        # C(j->i) = max(C(r->i) + I(j->r)), i <= r < j
        cl = stripe(s_c, n, w, (0, 0), 0) + stripe(s_i, n, w, (w, 0))
        cl_span, cl_path = cl.permute(2, 0, 1).max(-1)
        s_c.diagonal(-w).copy_(cl_span)
        p_c.diagonal(-w).copy_(cl_path + starts)
        # C(i->j) = max(I(i->r) + C(r->j)), i < r <= j
        cr = stripe(s_i, n, w, (0, 1)) + stripe(s_c, n, w, (1, w), 0)
        cr_span, cr_path = cr.permute(2, 0, 1).max(-1)
        s_c.diagonal(w).copy_(cr_span)
        s_c[0, w][lens.ne(w)] = float('-inf')
        p_c.diagonal(w).copy_(cr_path + starts + 1)

    def backtrack(p_i, p_c, heads, i, j, complete):
        if i == j:
            return
        if complete:
            r = p_c[i, j]
            backtrack(p_i, p_c, heads, i, r, False)
            backtrack(p_i, p_c, heads, r, j, True)
        else:
            r, heads[j] = p_i[i, j], i
            i, j = sorted((i, j))
            backtrack(p_i, p_c, heads, i, r, True)
            backtrack(p_i, p_c, heads, j, r + 1, True)

    preds = []
    p_c = p_c.permute(2, 0, 1).cpu()
    p_i = p_i.permute(2, 0, 1).cpu()
    for i, length in enumerate(lens.tolist()):
        heads = p_c.new_zeros(length + 1, dtype=torch.long)
        backtrack(p_i[i], p_c[i], heads, 0, length, True)
        preds.append(heads.to(mask.device))

    return pad(preds, total_length=seq_len).to(mask.device)
Example #10
File: tree.py Project: yzhangcs/parser
    def forward(self, semiring):
        s_arc, s_sib = self.scores
        batch_size, seq_len = s_arc.shape[:2]
        # [seq_len, seq_len, batch_size, ...], (h->m)
        s_arc = semiring.convert(s_arc.movedim((1, 2), (1, 0)))
        # [seq_len, seq_len, seq_len, batch_size, ...], (h->m->s)
        s_sib = semiring.convert(s_sib.movedim((0, 2), (3, 0)))
        s_i = semiring.zeros_like(s_arc)
        s_s = semiring.zeros_like(s_arc)
        s_c = semiring.zeros_like(s_arc)
        semiring.one_(s_c.diagonal().movedim(-1, 1))

        for w in range(1, seq_len):
            n = seq_len - w

            # INCOMPLETE-L: I(j->i) = <I(j->r), S(j->r, i)> * s(j->i), i < r < j
            #                         <C(j->j), C(i->j-1)>  * s(j->i), otherwise
            # [n, w, batch_size, ...]
            il = semiring.times(
                stripe(s_i, n, w, (w, 1)), stripe(s_s, n, w, (1, 0), 0),
                stripe(s_sib[range(w, n + w), range(n), :], n, w, (0, 1)))
            il[:, -1] = semiring.mul(stripe(s_c, n, 1, (w, w)),
                                     stripe(s_c, n, 1, (0, w - 1))).squeeze(1)
            il = semiring.sum(il, 1)
            s_i.diagonal(-w).copy_(
                semiring.mul(il,
                             s_arc.diagonal(-w).movedim(-1, 0)).movedim(0, -1))
            # INCOMPLETE-R: I(i->j) = <I(i->r), S(i->r, j)> * s(i->j), i < r < j
            #                         <C(i->i), C(j->i+1)>  * s(i->j), otherwise
            # [n, w, batch_size, ...]
            ir = semiring.times(
                stripe(s_i, n, w), stripe(s_s, n, w, (0, w), 0),
                stripe(s_sib[range(n), range(w, n + w), :], n, w))
            if not self.multiroot:
                semiring.zero_(ir[0])
            ir[:, 0] = semiring.mul(stripe(s_c, n, 1),
                                    stripe(s_c, n, 1, (w, 1))).squeeze(1)
            ir = semiring.sum(ir, 1)
            s_i.diagonal(w).copy_(
                semiring.mul(ir,
                             s_arc.diagonal(w).movedim(-1, 0)).movedim(0, -1))

            # [batch_size, ..., n]
            sl = sr = semiring.dot(stripe(s_c, n, w),
                                   stripe(s_c, n, w, (w, 1)), 1).movedim(0, -1)
            # SIB: S(j, i) = <C(i->r), C(j->r+1)>, i <= r < j
            s_s.diagonal(-w).copy_(sl)
            # SIB: S(i, j) = <C(i->r), C(j->r+1)>, i <= r < j
            s_s.diagonal(w).copy_(sr)

            # [n, batch_size, ...]
            # COMPLETE-L: C(j->i) = <C(r->i), I(j->r)>, i <= r < j
            cl = semiring.dot(stripe(s_c, n, w, (0, 0), 0),
                              stripe(s_i, n, w, (w, 0)), 1)
            s_c.diagonal(-w).copy_(cl.movedim(0, -1))
            # COMPLETE-R: C(i->j) = <I(i->r), C(r->j)>, i < r <= j
            cr = semiring.dot(stripe(s_i, n, w, (0, 1)),
                              stripe(s_c, n, w, (1, w), 0), 1)
            s_c.diagonal(w).copy_(cr.movedim(0, -1))
        return semiring.unconvert(s_c)[0][self.lens, range(batch_size)]
Example #11
def cky(scores, mask):
    r"""
    The implementation of the `Cocke-Kasami-Younger`_ (CKY) algorithm to parse constituency trees :cite:`zhang-etal-2020-fast`.

    Args:
        scores (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
            Scores of all candidate constituents over all candidate labels.
        mask (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``.
            The mask to avoid parsing over padding tokens.
            For each square matrix in a batch, the positions except upper triangular part should be masked out.

    Returns:
        Sequences of factorized predicted bracketed trees, with ``(i, j, label)`` triples traversed in pre-order.

    Examples:
        >>> scores = torch.tensor([[[ 2.5659,  1.4253, -2.5272,  3.3011],
                                    [ 1.3687, -0.5869,  1.0011,  3.3020],
                                    [ 1.2297,  0.4862,  1.1975,  2.5387],
                                    [-0.0511, -1.2541, -0.7577,  0.2659]]])
        >>> mask = torch.tensor([[[False,  True,  True,  True],
                                  [False, False,  True,  True],
                                  [False, False, False,  True],
                                  [False, False, False, False]]])
        >>> cky(scores.unsqueeze(-1), mask)  # add a singleton label dimension
        [[(0, 3, 0), (0, 1, 0), (1, 3, 0), (1, 2, 0), (2, 3, 0)]]

    .. _Cocke-Kasami-Younger:
        https://en.wikipedia.org/wiki/CYK_algorithm
    """

    lens = mask[:, 0].sum(-1)
    scores = scores.permute(1, 2, 3, 0)
    seq_len, seq_len, n_labels, batch_size = scores.shape
    s = scores.new_zeros(seq_len, seq_len, batch_size)
    p_s = scores.new_zeros(seq_len, seq_len, batch_size).long()
    p_l = scores.new_zeros(seq_len, seq_len, batch_size).long()

    for w in range(1, seq_len):
        n = seq_len - w
        starts = p_s.new_tensor(range(n)).unsqueeze(0)
        s_l, p = scores.diagonal(w).max(0)
        p_l.diagonal(w).copy_(p)

        if w == 1:
            s.diagonal(w).copy_(s_l)
            continue
        # [n, w, batch_size]
        s_s = stripe(s, n, w - 1, (0, 1)) + stripe(s, n, w - 1, (1, w), 0)
        # [batch_size, n, w]
        s_s = s_s.permute(2, 0, 1)
        # [batch_size, n]
        s_s, p = s_s.max(-1)
        s.diagonal(w).copy_(s_s + s_l)
        p_s.diagonal(w).copy_(p + starts + 1)

    def backtrack(p_s, p_l, i, j):
        if j == i + 1:
            return [(i, j, p_l[i][j])]
        split, label = p_s[i][j], p_l[i][j]
        ltree = backtrack(p_s, p_l, i, split)
        rtree = backtrack(p_s, p_l, split, j)
        return [(i, j, label)] + ltree + rtree

    p_s = p_s.permute(2, 0, 1).tolist()
    p_l = p_l.permute(2, 0, 1).tolist()
    trees = [
        backtrack(p_s[i], p_l[i], 0, length)
        for i, length in enumerate(lens.tolist())
    ]

    return trees
Example #12
def eisner2o(scores, mask):
    """
    Second-order Eisner algorithm for projective decoding.
    This is an extension of the first-order one and further incorporates sibling scores into tree scoring.

    References:
        - Ryan McDonald and Fernando Pereira (EACL'06)
          Online Learning of Approximate Dependency Parsing Algorithms
          https://www.aclweb.org/anthology/E06-1011/

    Args:
        scores (tuple[torch.Tensor, torch.Tensor]):
            A tuple of two tensors representing the first-order and second-order scores respectively.
            The first ([batch_size, seq_len, seq_len]) holds scores of dependent-head pairs.
            The second ([batch_size, seq_len, seq_len, seq_len]) holds scores of the dependent-head-sibling triples.
        mask (torch.BoolTensor): [batch_size, seq_len]
            Mask to avoid parsing over padding tokens.
            The first column with pseudo words as roots should be set to False.

    Returns:
        Tensor: [batch_size, seq_len]
            Projective parse trees.
    """

    # the end position of each sentence in a batch
    lens = mask.sum(1)
    s_arc, s_sib = scores
    batch_size, seq_len, _ = s_arc.shape
    # [seq_len, seq_len, batch_size]
    s_arc = s_arc.permute(2, 1, 0)
    # [seq_len, seq_len, seq_len, batch_size]
    s_sib = s_sib.permute(2, 1, 3, 0)
    s_i = torch.full_like(s_arc, float('-inf'))
    s_s = torch.full_like(s_arc, float('-inf'))
    s_c = torch.full_like(s_arc, float('-inf'))
    p_i = s_arc.new_zeros(seq_len, seq_len, batch_size).long()
    p_s = s_arc.new_zeros(seq_len, seq_len, batch_size).long()
    p_c = s_arc.new_zeros(seq_len, seq_len, batch_size).long()
    s_c.diagonal().fill_(0)

    for w in range(1, seq_len):
        # n denotes the number of spans to iterate,
        # from span (0, w) to span (n-1, n-1+w) given width w
        n = seq_len - w
        starts = p_i.new_tensor(range(n)).unsqueeze(0)
        # I(j->i) = max(I(j->r) + S(j->r, i), i < r < j |
        #               C(j->j) + C(i->j-1))
        #           + s(j->i)
        # [n, w, batch_size]
        il = stripe(s_i, n, w, (w, 1)) + stripe(s_s, n, w, (1, 0), 0)
        il += stripe(s_sib[range(w, n+w), range(n)], n, w, (0, 1))
        # [n, 1, batch_size]
        il0 = stripe(s_c, n, 1, (w, w)) + stripe(s_c, n, 1, (0, w - 1))
        # il0[0] is set to 0 since the scores of complete spans starting from 0 are always -inf
        il[:, -1] = il0.index_fill_(0, lens.new_tensor(0), 0).squeeze(1)
        il_span, il_path = il.permute(2, 0, 1).max(-1)
        s_i.diagonal(-w).copy_(il_span + s_arc.diagonal(-w))
        p_i.diagonal(-w).copy_(il_path + starts + 1)
        # I(i->j) = max(I(i->r) + S(i->r, j), i < r < j |
        #               C(i->i) + C(j->i+1))
        #           + s(i->j)
        # [n, w, batch_size]
        ir = stripe(s_i, n, w) + stripe(s_s, n, w, (0, w), 0)
        ir += stripe(s_sib[range(n), range(w, n+w)], n, w)
        ir[0] = float('-inf')
        # [n, 1, batch_size]
        ir0 = stripe(s_c, n, 1) + stripe(s_c, n, 1, (w, 1))
        ir[:, 0] = ir0.squeeze(1)
        ir_span, ir_path = ir.permute(2, 0, 1).max(-1)
        s_i.diagonal(w).copy_(ir_span + s_arc.diagonal(w))
        p_i.diagonal(w).copy_(ir_path + starts)

        # [n, w, batch_size]
        slr = stripe(s_c, n, w) + stripe(s_c, n, w, (w, 1))
        slr_span, slr_path = slr.permute(2, 0, 1).max(-1)
        # S(j, i) = max(C(i->r) + C(j->r+1)), i <= r < j
        s_s.diagonal(-w).copy_(slr_span)
        p_s.diagonal(-w).copy_(slr_path + starts)
        # S(i, j) = max(C(i->r) + C(j->r+1)), i <= r < j
        s_s.diagonal(w).copy_(slr_span)
        p_s.diagonal(w).copy_(slr_path + starts)

        # C(j->i) = max(C(r->i) + I(j->r)), i <= r < j
        cl = stripe(s_c, n, w, (0, 0), 0) + stripe(s_i, n, w, (w, 0))
        cl_span, cl_path = cl.permute(2, 0, 1).max(-1)
        s_c.diagonal(-w).copy_(cl_span)
        p_c.diagonal(-w).copy_(cl_path + starts)
        # C(i->j) = max(I(i->r) + C(r->j)), i < r <= j
        cr = stripe(s_i, n, w, (0, 1)) + stripe(s_c, n, w, (1, w), 0)
        cr_span, cr_path = cr.permute(2, 0, 1).max(-1)
        s_c.diagonal(w).copy_(cr_span)
        # forbid multiple words from modifying the root (single-root constraint)
        s_c[0, w][lens.ne(w)] = float('-inf')
        p_c.diagonal(w).copy_(cr_path + starts + 1)

    def backtrack(p_i, p_s, p_c, heads, i, j, flag):
        if i == j:
            return
        if flag == 'c':
            r = p_c[i, j]
            backtrack(p_i, p_s, p_c, heads, i, r, 'i')
            backtrack(p_i, p_s, p_c, heads, r, j, 'c')
        elif flag == 's':
            r = p_s[i, j]
            i, j = sorted((i, j))
            backtrack(p_i, p_s, p_c, heads, i, r, 'c')
            backtrack(p_i, p_s, p_c, heads, j, r + 1, 'c')
        elif flag == 'i':
            r, heads[j] = p_i[i, j], i
            if r == i:
                r = i + 1 if i < j else i - 1
                backtrack(p_i, p_s, p_c, heads, j, r, 'c')
            else:
                backtrack(p_i, p_s, p_c, heads, i, r, 'i')
                backtrack(p_i, p_s, p_c, heads, r, j, 's')

    preds = []
    p_i = p_i.permute(2, 0, 1).cpu()
    p_s = p_s.permute(2, 0, 1).cpu()
    p_c = p_c.permute(2, 0, 1).cpu()
    for i, length in enumerate(lens.tolist()):
        heads = p_c.new_zeros(length + 1, dtype=torch.long)
        backtrack(p_i[i], p_s[i], p_c[i], heads, 0, length, 'c')
        preds.append(heads.to(mask.device))

    return pad(preds, total_length=seq_len).to(mask.device)
Example #13
def eisner(scores, mask):
    """
    First-order Eisner algorithm for projective decoding.

    References:
        - Ryan McDonald, Koby Crammer and Fernando Pereira (ACL'05)
          Online Large-Margin Training of Dependency Parsers
          https://www.aclweb.org/anthology/P05-1012/

    Args:
        scores (torch.Tensor): [batch_size, seq_len, seq_len]
            The scores of dependent-head pairs.
        mask (torch.BoolTensor): [batch_size, seq_len]
            Mask to avoid parsing over padding tokens.
            The first column with pseudo words as roots should be set to False.

    Returns:
        Tensor: [batch_size, seq_len]
            Projective parse trees.
    """

    lens = mask.sum(1)
    batch_size, seq_len, _ = scores.shape
    scores = scores.permute(2, 1, 0)
    s_i = torch.full_like(scores, float('-inf'))
    s_c = torch.full_like(scores, float('-inf'))
    p_i = scores.new_zeros(seq_len, seq_len, batch_size).long()
    p_c = scores.new_zeros(seq_len, seq_len, batch_size).long()
    s_c.diagonal().fill_(0)

    for w in range(1, seq_len):
        n = seq_len - w
        starts = p_i.new_tensor(range(n)).unsqueeze(0)
        # ilr = C(i->r) + C(j->r+1)
        ilr = stripe(s_c, n, w) + stripe(s_c, n, w, (w, 1))
        # [batch_size, n, w]
        il = ir = ilr.permute(2, 0, 1)
        # I(j->i) = max(C(i->r) + C(j->r+1) + s(j->i)), i <= r < j
        il_span, il_path = il.max(-1)
        s_i.diagonal(-w).copy_(il_span + scores.diagonal(-w))
        p_i.diagonal(-w).copy_(il_path + starts)
        # I(i->j) = max(C(i->r) + C(j->r+1) + s(i->j)), i <= r < j
        ir_span, ir_path = ir.max(-1)
        s_i.diagonal(w).copy_(ir_span + scores.diagonal(w))
        p_i.diagonal(w).copy_(ir_path + starts)

        # C(j->i) = max(C(r->i) + I(j->r)), i <= r < j
        cl = stripe(s_c, n, w, (0, 0), 0) + stripe(s_i, n, w, (w, 0))
        cl_span, cl_path = cl.permute(2, 0, 1).max(-1)
        s_c.diagonal(-w).copy_(cl_span)
        p_c.diagonal(-w).copy_(cl_path + starts)
        # C(i->j) = max(I(i->r) + C(r->j)), i < r <= j
        cr = stripe(s_i, n, w, (0, 1)) + stripe(s_c, n, w, (1, w), 0)
        cr_span, cr_path = cr.permute(2, 0, 1).max(-1)
        s_c.diagonal(w).copy_(cr_span)
        s_c[0, w][lens.ne(w)] = float('-inf')
        p_c.diagonal(w).copy_(cr_path + starts + 1)

    def backtrack(p_i, p_c, heads, i, j, complete):
        if i == j:
            return
        if complete:
            r = p_c[i, j]
            backtrack(p_i, p_c, heads, i, r, False)
            backtrack(p_i, p_c, heads, r, j, True)
        else:
            r, heads[j] = p_i[i, j], i
            i, j = sorted((i, j))
            backtrack(p_i, p_c, heads, i, r, True)
            backtrack(p_i, p_c, heads, j, r + 1, True)

    preds = []
    p_c = p_c.permute(2, 0, 1).cpu()
    p_i = p_i.permute(2, 0, 1).cpu()
    for i, length in enumerate(lens.tolist()):
        heads = p_c.new_zeros(length + 1, dtype=torch.long)
        backtrack(p_i[i], p_c[i], heads, 0, length, True)
        preds.append(heads.to(mask.device))

    return pad(preds, total_length=seq_len).to(mask.device)
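A usage sketch of this `eisner`, with the same toy inputs as the doctest in the earlier listing (Example #9):

import torch

scores = torch.tensor([[[-13.5026, -18.3700, -13.0033, -16.6809],
                        [-36.5235, -28.6344, -28.4696, -31.6750],
                        [ -2.9084,  -7.4825,  -1.4861,  -6.8709],
                        [-29.4880, -27.6905, -26.1498, -27.0233]]])
mask = torch.tensor([[False, True, True, True]])
print(eisner(scores, mask))
# tensor([[0, 2, 0, 2]])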
Example #14

    def inside(self, scores, mask, cands=None):
        # the end position of each sentence in a batch
        lens = mask.sum(1)
        s_arc, s_sib = scores
        batch_size, seq_len, _ = s_arc.shape
        # [seq_len, seq_len, batch_size]
        s_arc = s_arc.permute(2, 1, 0)
        # [seq_len, seq_len, seq_len, batch_size]
        s_sib = s_sib.permute(2, 1, 3, 0)
        s_i = torch.full_like(s_arc, float('-inf'))
        s_s = torch.full_like(s_arc, float('-inf'))
        s_c = torch.full_like(s_arc, float('-inf'))
        s_c.diagonal().fill_(0)

        # set the scores of arcs excluded by cands to -inf
        if cands is not None:
            mask = mask.index_fill(1, lens.new_tensor(0), 1)
            mask = (mask.unsqueeze(1) & mask.unsqueeze(-1)).permute(2, 1, 0)
            cands = cands.unsqueeze(-1).index_fill(1, lens.new_tensor(0), -1)
            cands = cands.eq(lens.new_tensor(range(seq_len))) | cands.lt(0)
            cands = cands.permute(2, 1, 0) & mask
            s_arc = s_arc.masked_fill(~cands, float('-inf'))

        for w in range(1, seq_len):
            # n denotes the number of spans to iterate,
            # from span (0, w) to span (n-1, n-1+w) given width w
            n = seq_len - w
            # I(j->i) = logsumexp(I(j->r) + S(j->r, i), i < r < j;
            #                     C(j->j) + C(i->j-1))
            #           + s(j->i)
            # [n, w, batch_size]
            il = stripe(s_i, n, w, (w, 1)) + stripe(s_s, n, w, (1, 0), 0)
            il += stripe(s_sib[range(w, n + w), range(n)], n, w, (0, 1))
            # [n, 1, batch_size]
            il0 = stripe(s_c, n, 1, (w, w)) + stripe(s_c, n, 1, (0, w - 1))
            # il0[0] is set to 0 since the scores of complete spans starting from 0 are always -inf
            il[:, -1] = il0.index_fill_(0, lens.new_tensor(0), 0).squeeze(1)
            if il.requires_grad:
                il.register_hook(lambda x: x.masked_fill_(torch.isnan(x), 0))
            il = il.permute(2, 0, 1).logsumexp(-1)
            s_i.diagonal(-w).copy_(il + s_arc.diagonal(-w))
            # I(i->j) = logsumexp(I(i->r) + S(i->r, j), i < r < j;
            #                     C(i->i) + C(j->i+1))
            #           + s(i->j)
            # [n, w, batch_size]
            ir = stripe(s_i, n, w) + stripe(s_s, n, w, (0, w), 0)
            ir += stripe(s_sib[range(n), range(w, n + w)], n, w)
            ir[0] = float('-inf')
            # [n, 1, batch_size]
            ir0 = stripe(s_c, n, 1) + stripe(s_c, n, 1, (w, 1))
            ir[:, 0] = ir0.squeeze(1)
            if ir.requires_grad:
                ir.register_hook(lambda x: x.masked_fill_(torch.isnan(x), 0))
            ir = ir.permute(2, 0, 1).logsumexp(-1)
            s_i.diagonal(w).copy_(ir + s_arc.diagonal(w))

            # [n, w, batch_size]
            slr = stripe(s_c, n, w) + stripe(s_c, n, w, (w, 1))
            if slr.requires_grad:
                slr.register_hook(lambda x: x.masked_fill_(torch.isnan(x), 0))
            slr = slr.permute(2, 0, 1).logsumexp(-1)
            # S(j, i) = logsumexp(C(i->r) + C(j->r+1)), i <= r < j
            s_s.diagonal(-w).copy_(slr)
            # S(i, j) = logsumexp(C(i->r) + C(j->r+1)), i <= r < j
            s_s.diagonal(w).copy_(slr)

            # C(j->i) = logsumexp(C(r->i) + I(j->r)), i <= r < j
            cl = stripe(s_c, n, w, (0, 0), 0) + stripe(s_i, n, w, (w, 0))
            if cl.requires_grad:
                cl.register_hook(lambda x: x.masked_fill_(torch.isnan(x), 0))
            s_c.diagonal(-w).copy_(cl.permute(2, 0, 1).logsumexp(-1))
            # C(i->j) = logsumexp(I(i->r) + C(r->j)), i < r <= j
            cr = stripe(s_i, n, w, (0, 1)) + stripe(s_c, n, w, (1, w), 0)
            if cr.requires_grad:
                cr.register_hook(lambda x: x.masked_fill_(torch.isnan(x), 0))
            s_c.diagonal(w).copy_(cr.permute(2, 0, 1).logsumexp(-1))
            # forbid multiple words from modifying the root (single-root constraint)
            s_c[0, w][lens.ne(w)] = float('-inf')

        return s_c[0].gather(0, lens.unsqueeze(0)).sum()
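As in the first-order case, backpropagating through this second-order `inside` recovers marginals for both score tensors. A hypothetical sketch; `crf2o` is scaffolding for whatever object carries the method above:

import torch

batch_size, seq_len = 2, 6
s_arc = torch.randn(batch_size, seq_len, seq_len, requires_grad=True)
s_sib = torch.randn(batch_size, seq_len, seq_len, seq_len, requires_grad=True)
mask = torch.ones(batch_size, seq_len, dtype=torch.bool)
mask[:, 0] = False  # position 0 is the pseudo root

logZ = crf2o.inside((s_arc, s_sib), mask)  # assumed: crf2o carries the method above
logZ.backward()
# s_arc.grad and s_sib.grad hold the arc and sibling marginals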