def inside(self, scores, mask):
    lens = mask[:, 0].sum(-1)
    batch_size, seq_len, _ = scores.shape
    # [seq_len, seq_len, batch_size]
    scores, mask = scores.permute(1, 2, 0), mask.permute(1, 2, 0)
    s = torch.full_like(scores, float('-inf'))

    for w in range(1, seq_len):
        # n denotes the number of spans to iterate,
        # from span (0, w) to span (n, n+w) given width w
        n = seq_len - w

        if w == 1:
            s.diagonal(w).copy_(scores.diagonal(w))
            continue
        # [n, w, batch_size]
        s_s = stripe(s, n, w - 1, (0, 1)) + stripe(s, n, w - 1, (1, w), 0)
        # [batch_size, n, w]
        s_s = s_s.permute(2, 0, 1)
        if s_s.requires_grad:
            s_s.register_hook(lambda x: x.masked_fill_(torch.isnan(x), 0))
        s_s = s_s.logsumexp(-1)
        s.diagonal(w).copy_(s_s + scores.diagonal(w))

    return s[0].gather(0, lens.unsqueeze(0)).sum()
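# NOTE: an illustrative sketch, not part of the source. Every routine in this
# file relies on a `stripe` helper; the version below is reconstructed from its
# call sites (stripe(x, n, w, offset, dim)). It returns a zero-copy strided
# view, so writes through it update the chart in place. The real helper may
# differ in details.
import torch

def stripe(x, n, w, offset=(0, 0), dim=1):
    # stripe[k, j] = x[offset[0]+k, offset[1]+k+j] if dim == 1 (horizontal),
    #                x[offset[0]+k+j, offset[1]+k] if dim == 0 (vertical)
    x, seq_len = x.contiguous(), x.size(1)
    stride, numel = list(x.stride()), x[0, 0].numel()
    stride[0] = (seq_len + 1) * numel
    stride[1] = (1 if dim == 1 else seq_len) * numel
    return x.as_strided(size=(n, w, *x.shape[2:]),
                        stride=stride,
                        storage_offset=(offset[0] * seq_len + offset[1]) * numel)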
def forward(self, semiring):
    batch_size, seq_len = self.scores.shape[:2]
    # [seq_len, seq_len, batch_size, ...], (l->r)
    scores = semiring.convert(self.scores.movedim((1, 2), (0, 1)))
    s = semiring.zeros_like(scores)
    s.diagonal(1).copy_(scores.diagonal(1))

    for w in range(2, seq_len):
        n = seq_len - w
        # [n, batch_size, ...]
        s_s = semiring.dot(stripe(s, n, w - 1, (0, 1)),
                           stripe(s, n, w - 1, (1, w), False), 1)
        s.diagonal(w).copy_(
            semiring.mul(s_s, scores.diagonal(w).movedim(-1, 0)).movedim(0, -1))
    return semiring.unconvert(s)[0][self.lens, range(batch_size)]
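# NOTE: an illustrative sketch, not part of the source. The semiring-generic
# forwards in this file only assume a small interface (convert/unconvert,
# zeros_like, zero_/one_, mul/times, sum, dot); a hypothetical log semiring
# satisfying it is given below, under which `forward` computes log-partition
# values. The real class may differ.
import torch

class LogSemiring:
    zero, one = float('-inf'), 0.

    @classmethod
    def convert(cls, x):
        # log potentials need no extra bookkeeping dims
        return x

    @classmethod
    def unconvert(cls, x):
        return x

    @classmethod
    def zeros_like(cls, x):
        return torch.full_like(x, cls.zero)

    @classmethod
    def zero_(cls, x):
        return x.fill_(cls.zero)

    @classmethod
    def one_(cls, x):
        return x.fill_(cls.one)

    @classmethod
    def mul(cls, x, y):
        # semiring multiplication is addition in log space
        return x + y

    @classmethod
    def times(cls, *xs):
        return sum(xs)

    @classmethod
    def sum(cls, x, dim=-1):
        # semiring addition is logsumexp in log space
        return x.logsumexp(dim)

    @classmethod
    def dot(cls, x, y, dim=-1):
        return cls.sum(cls.mul(x, y), dim)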
def forward(self, semiring):
    s_dep, s_con = self.scores[0], self.scores[1]
    batch_size, seq_len, *_ = s_con.shape
    # [seq_len, seq_len, batch_size, ...], (m<-h)
    s_dep = semiring.convert(s_dep.movedim(0, 2))
    s_root, s_dep = s_dep[1:, 0], s_dep[1:, 1:]
    # [seq_len, seq_len, batch_size, ...], (l->r)
    s_con = semiring.convert(s_con.movedim((1, 2), (0, 1)))
    # [seq_len, seq_len, seq_len, batch_size, ...], (i, j, h)
    s_span = semiring.zero_(s_con.new_empty(seq_len, seq_len, seq_len - 1, *s_con.shape[2:]))
    # [seq_len, seq_len, seq_len, batch_size, ...], (i, j<-h)
    s_hook = semiring.zero_(s_con.new_empty(seq_len, seq_len, seq_len - 1, *s_con.shape[2:]))
    diagonal_stripe(s_span, 1).copy_(s_con.diagonal(1).movedim(-1, 0).unsqueeze(1))
    s_hook.diagonal(1).copy_(
        semiring.mul(s_dep, s_con.diagonal(1).movedim(-1, 0).unsqueeze(1)).movedim(0, -1))

    for w in range(2, seq_len):
        n = seq_len - w

        # COMPLETE-L: s_span_l(i, j, h) = <s_span(i, k, h), s_hook(h->k, j)>, i < k < j
        # [n, w, batch_size, ...]
        s_l = stripe(semiring.dot(stripe(s_span, n, w - 1, (0, 1)),
                                  stripe(s_hook, n, w - 1, (1, w), False), 1), n, w)
        # COMPLETE-R: s_span_r(i, j, h) = <s_hook(i, k<-h), s_span(k, j, h)>, i < k < j
        # [n, w, batch_size, ...]
        s_r = stripe(semiring.dot(stripe(s_hook, n, w - 1, (0, 1)),
                                  stripe(s_span, n, w - 1, (1, w), False), 1), n, w)
        # COMPLETE: s_span(i, j, h) = s_span_l(i, j, h) + s_span_r(i, j, h) + s(i, j)
        # [n, w, batch_size, ...]
        s = semiring.mul(semiring.sum(torch.stack((s_l, s_r)), 0),
                         s_con.diagonal(w).movedim(-1, 0).unsqueeze(1))
        diagonal_stripe(s_span, w).copy_(s)

        if w == seq_len - 1:
            continue
        # ATTACH: s_hook(h->i, j) = <s(h->m), s_span(i, j, m)>, i <= m < j
        # [n, seq_len, batch_size, ...]
        s = semiring.dot(expanded_stripe(s_dep, n, w),
                         diagonal_stripe(s_span, w).unsqueeze(2), 1)
        s_hook.diagonal(w).copy_(s.movedim(0, -1))
    return semiring.unconvert(
        semiring.dot(s_span[0][self.lens, :, range(batch_size)].transpose(0, 1), s_root, 0))
def inside(self, scores, mask, cands=None):
    # the end position of each sentence in a batch
    lens = mask.sum(1)
    batch_size, seq_len, _ = scores.shape
    # [seq_len, seq_len, batch_size]
    scores = scores.permute(2, 1, 0)
    s_i = torch.full_like(scores, float('-inf'))
    s_c = torch.full_like(scores, float('-inf'))
    s_c.diagonal().fill_(0)
    # set the scores of arcs excluded by cands to -inf
    if cands is not None:
        mask = mask.index_fill(1, lens.new_tensor(0), 1)
        mask = (mask.unsqueeze(1) & mask.unsqueeze(-1)).permute(2, 1, 0)
        cands = cands.unsqueeze(-1).index_fill(1, lens.new_tensor(0), -1)
        cands = cands.eq(lens.new_tensor(range(seq_len))) | cands.lt(0)
        cands = cands.permute(2, 1, 0) & mask
        scores = scores.masked_fill(~cands, float('-inf'))

    for w in range(1, seq_len):
        # n denotes the number of spans to iterate,
        # from span (0, w) to span (n, n+w) given width w
        n = seq_len - w

        # ilr = C(i->r) + C(j->r+1)
        # [n, w, batch_size]
        ilr = stripe(s_c, n, w) + stripe(s_c, n, w, (w, 1))
        if ilr.requires_grad:
            ilr.register_hook(lambda x: x.masked_fill_(torch.isnan(x), 0))
        il = ir = ilr.permute(2, 0, 1).logsumexp(-1)
        # I(j->i) = logsumexp(C(i->r) + C(j->r+1)) + s(j->i), i <= r < j
        # fill the w-th diagonal of the lower triangular part of s_i
        # with I(j->i) of n spans
        s_i.diagonal(-w).copy_(il + scores.diagonal(-w))
        # I(i->j) = logsumexp(C(i->r) + C(j->r+1)) + s(i->j), i <= r < j
        # fill the w-th diagonal of the upper triangular part of s_i
        # with I(i->j) of n spans
        s_i.diagonal(w).copy_(ir + scores.diagonal(w))

        # C(j->i) = logsumexp(C(r->i) + I(j->r)), i <= r < j
        cl = stripe(s_c, n, w, (0, 0), 0) + stripe(s_i, n, w, (w, 0))
        cl.register_hook(lambda x: x.masked_fill_(torch.isnan(x), 0))
        s_c.diagonal(-w).copy_(cl.permute(2, 0, 1).logsumexp(-1))
        # C(i->j) = logsumexp(I(i->r) + C(r->j)), i < r <= j
        cr = stripe(s_i, n, w, (0, 1)) + stripe(s_c, n, w, (1, w), 0)
        cr.register_hook(lambda x: x.masked_fill_(torch.isnan(x), 0))
        s_c.diagonal(w).copy_(cr.permute(2, 0, 1).logsumexp(-1))
        # enforce the single-root constraint: the root's complete span must cover the whole sentence
        s_c[0, w][lens.ne(w)] = float('-inf')

    return s_c[0].gather(0, lens.unsqueeze(0)).sum()
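# NOTE: an illustrative demo, not part of the source. The register_hook calls
# above exist because logsumexp over a fully masked (all -inf) slice returns
# -inf, and back-propagating through it produces NaN gradients; the hooks zero
# those NaNs so they cannot poison the rest of the graph.
import torch

x = torch.full((3,), float('-inf'), requires_grad=True)
x.logsumexp(-1).backward()
print(x.grad)  # tensor([nan, nan, nan])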
def forward(self, semiring):
    s_arc = self.scores
    batch_size, seq_len = s_arc.shape[:2]
    # [seq_len, seq_len, batch_size, ...], (h->m)
    s_arc = semiring.convert(s_arc.movedim((1, 2), (1, 0)))
    s_i = semiring.zeros_like(s_arc)
    s_c = semiring.zeros_like(s_arc)
    semiring.one_(s_c.diagonal().movedim(-1, 1))

    for w in range(1, seq_len):
        n = seq_len - w

        # [n, batch_size, ...]
        il = ir = semiring.dot(stripe(s_c, n, w), stripe(s_c, n, w, (w, 1)), 1)
        # INCOMPLETE-L: I(j->i) = <C(i->r), C(j->r+1)> * s(j->i), i <= r < j
        # fill the w-th diagonal of the lower triangular part of s_i with I(j->i) of n spans
        s_i.diagonal(-w).copy_(
            semiring.mul(il, s_arc.diagonal(-w).movedim(-1, 0)).movedim(0, -1))
        # INCOMPLETE-R: I(i->j) = <C(i->r), C(j->r+1)> * s(i->j), i <= r < j
        # fill the w-th diagonal of the upper triangular part of s_i with I(i->j) of n spans
        s_i.diagonal(w).copy_(
            semiring.mul(ir, s_arc.diagonal(w).movedim(-1, 0)).movedim(0, -1))

        # [n, batch_size, ...]
        # COMPLETE-L: C(j->i) = <C(r->i), I(j->r)>, i <= r < j
        cl = semiring.dot(stripe(s_c, n, w, (0, 0), 0), stripe(s_i, n, w, (w, 0)), 1)
        s_c.diagonal(-w).copy_(cl.movedim(0, -1))
        # COMPLETE-R: C(i->j) = <I(i->r), C(r->j)>, i < r <= j
        cr = semiring.dot(stripe(s_i, n, w, (0, 1)), stripe(s_c, n, w, (1, w), 0), 1)
        s_c.diagonal(w).copy_(cr.movedim(0, -1))
        if not self.multiroot:
            s_c[0, w][self.lens.ne(w)] = semiring.zero
    return semiring.unconvert(s_c)[0][self.lens, range(batch_size)]
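# NOTE: an illustrative demo, not part of the source. Tensor.diagonal appends
# the diagonal as the *last* dimension and returns a writable view, which is
# why the charts above are laid out as [seq_len, seq_len, batch_size, ...] and
# every read/write shuffles dims with movedim/permute.
import torch

x = torch.zeros(4, 4, 2)    # [seq_len, seq_len, batch_size]
print(x.diagonal(1).shape)  # torch.Size([2, 3]): batch_size first, diagonal last
x.diagonal(1).copy_(torch.ones(2, 3))
print(x[0, 1])              # tensor([1., 1.]): the write went through the view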
def cky(scores, mask): r""" The implementation of `Cocke-Kasami-Younger`_ (CKY) algorithm to parse constituency trees. References: - Yu Zhang, Houquan Zhou and Zhenghua Li. 2020. `Fast and Accurate Neural CRF Constituency Parsing`_. Args: scores (~torch.Tensor): ``[batch_size, seq_len, seq_len]``. Scores of all candidate constituents. mask (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``. The mask to avoid parsing over padding tokens. For each square matrix in a batch, the positions except upper triangular part should be masked out. Returns: Sequences of factorized predicted bracketed trees that are traversed in pre-order. Examples: >>> scores = torch.tensor([[[ 2.5659, 1.4253, -2.5272, 3.3011], [ 1.3687, -0.5869, 1.0011, 3.3020], [ 1.2297, 0.4862, 1.1975, 2.5387], [-0.0511, -1.2541, -0.7577, 0.2659]]]) >>> mask = torch.tensor([[[False, True, True, True], [False, False, True, True], [False, False, False, True], [False, False, False, False]]]) >>> cky(scores, mask) [[(0, 3), (0, 1), (1, 3), (1, 2), (2, 3)]] .. _Cocke-Kasami-Younger: https://en.wikipedia.org/wiki/CYK_algorithm .. _Fast and Accurate Neural CRF Constituency Parsing: https://www.ijcai.org/Proceedings/2020/560/ """ lens = mask[:, 0].sum(-1) scores = scores.permute(1, 2, 0) seq_len, seq_len, batch_size = scores.shape s = scores.new_zeros(seq_len, seq_len, batch_size) p = scores.new_zeros(seq_len, seq_len, batch_size).long() for w in range(1, seq_len): n = seq_len - w starts = p.new_tensor(range(n)).unsqueeze(0) if w == 1: s.diagonal(w).copy_(scores.diagonal(w)) continue # [n, w, batch_size] s_span = stripe(s, n, w - 1, (0, 1)) + stripe(s, n, w - 1, (1, w), 0) # [batch_size, n, w] s_span = s_span.permute(2, 0, 1) # [batch_size, n] s_span, p_span = s_span.max(-1) s.diagonal(w).copy_(s_span + scores.diagonal(w)) p.diagonal(w).copy_(p_span + starts + 1) def backtrack(p, i, j): if j == i + 1: return [(i, j)] split = p[i][j] ltree = backtrack(p, i, split) rtree = backtrack(p, split, j) return [(i, j)] + ltree + rtree p = p.permute(2, 0, 1).tolist() trees = [ backtrack(p[i], 0, length) for i, length in enumerate(lens.tolist()) ] return trees
def eisner2o(scores, mask): r""" Second-order Eisner algorithm for projective decoding. This is an extension of the first-order one that further incorporates sibling scores into tree scoring. References: - Ryan McDonald and Fernando Pereira. 2006. `Online Learning of Approximate Dependency Parsing Algorithms`_. Args: scores (~torch.Tensor, ~torch.Tensor): A tuple of two tensors representing the first-order and second-order scores repectively. The first (``[batch_size, seq_len, seq_len]``) holds scores of all dependent-head pairs. The second (``[batch_size, seq_len, seq_len, seq_len]``) holds scores of all dependent-head-sibling triples. mask (~torch.BoolTensor): ``[batch_size, seq_len]``. The mask to avoid parsing over padding tokens. The first column serving as pseudo words for roots should be ``False``. Returns: ~torch.Tensor: A tensor with shape ``[batch_size, seq_len]`` for the resulting projective parse trees. Examples: >>> s_arc = torch.tensor([[[ -2.8092, -7.9104, -0.9414, -5.4360], [-10.3494, -7.9298, -3.6929, -7.3985], [ 1.1815, -3.8291, 2.3166, -2.7183], [ -3.9776, -3.9063, -1.6762, -3.1861]]]) >>> s_sib = torch.tensor([[[[ 0.4719, 0.4154, 1.1333, 0.6946], [ 1.1252, 1.3043, 2.1128, 1.4621], [ 0.5974, 0.5635, 1.0115, 0.7550], [ 1.1174, 1.3794, 2.2567, 1.4043]], [[-2.1480, -4.1830, -2.5519, -1.8020], [-1.2496, -1.7859, -0.0665, -0.4938], [-2.6171, -4.0142, -2.9428, -2.2121], [-0.5166, -1.0925, 0.5190, 0.1371]], [[ 0.5827, -1.2499, -0.0648, -0.0497], [ 1.4695, 0.3522, 1.5614, 1.0236], [ 0.4647, -0.7996, -0.3801, 0.0046], [ 1.5611, 0.3875, 1.8285, 1.0766]], [[-1.3053, -2.9423, -1.5779, -1.2142], [-0.1908, -0.9699, 0.3085, 0.1061], [-1.6783, -2.8199, -1.8853, -1.5653], [ 0.3629, -0.3488, 0.9011, 0.5674]]]]) >>> mask = torch.tensor([[False, True, True, True]]) >>> eisner2o((s_arc, s_sib), mask) tensor([[0, 2, 0, 2]]) .. 
_Online Learning of Approximate Dependency Parsing Algorithms: https://www.aclweb.org/anthology/E06-1011/ """ # the end position of each sentence in a batch lens = mask.sum(1) s_arc, s_sib = scores batch_size, seq_len, _ = s_arc.shape # [seq_len, seq_len, batch_size] s_arc = s_arc.permute(2, 1, 0) # [seq_len, seq_len, seq_len, batch_size] s_sib = s_sib.permute(2, 1, 3, 0) s_i = torch.full_like(s_arc, float('-inf')) s_s = torch.full_like(s_arc, float('-inf')) s_c = torch.full_like(s_arc, float('-inf')) p_i = s_arc.new_zeros(seq_len, seq_len, batch_size).long() p_s = s_arc.new_zeros(seq_len, seq_len, batch_size).long() p_c = s_arc.new_zeros(seq_len, seq_len, batch_size).long() s_c.diagonal().fill_(0) for w in range(1, seq_len): # n denotes the number of spans to iterate, # from span (0, w) to span (n, n+w) given width w n = seq_len - w starts = p_i.new_tensor(range(n)).unsqueeze(0) # I(j->i) = max(I(j->r) + S(j->r, i)), i < r < j | # C(j->j) + C(i->j-1)) # + s(j->i) # [n, w, batch_size] il = stripe(s_i, n, w, (w, 1)) + stripe(s_s, n, w, (1, 0), 0) il += stripe(s_sib[range(w, n + w), range(n)], n, w, (0, 1)) # [n, 1, batch_size] il0 = stripe(s_c, n, 1, (w, w)) + stripe(s_c, n, 1, (0, w - 1)) # il0[0] are set to zeros since the scores of the complete spans starting from 0 are always -inf il[:, -1] = il0.index_fill_(0, lens.new_tensor(0), 0).squeeze(1) il_span, il_path = il.permute(2, 0, 1).max(-1) s_i.diagonal(-w).copy_(il_span + s_arc.diagonal(-w)) p_i.diagonal(-w).copy_(il_path + starts + 1) # I(i->j) = max(I(i->r) + S(i->r, j), i < r < j | # C(i->i) + C(j->i+1)) # + s(i->j) # [n, w, batch_size] ir = stripe(s_i, n, w) + stripe(s_s, n, w, (0, w), 0) ir += stripe(s_sib[range(n), range(w, n + w)], n, w) ir[0] = float('-inf') # [n, 1, batch_size] ir0 = stripe(s_c, n, 1) + stripe(s_c, n, 1, (w, 1)) ir[:, 0] = ir0.squeeze(1) ir_span, ir_path = ir.permute(2, 0, 1).max(-1) s_i.diagonal(w).copy_(ir_span + s_arc.diagonal(w)) p_i.diagonal(w).copy_(ir_path + starts) # [n, w, batch_size] slr = stripe(s_c, n, w) + stripe(s_c, n, w, (w, 1)) slr_span, slr_path = slr.permute(2, 0, 1).max(-1) # S(j, i) = max(C(i->r) + C(j->r+1)), i <= r < j s_s.diagonal(-w).copy_(slr_span) p_s.diagonal(-w).copy_(slr_path + starts) # S(i, j) = max(C(i->r) + C(j->r+1)), i <= r < j s_s.diagonal(w).copy_(slr_span) p_s.diagonal(w).copy_(slr_path + starts) # C(j->i) = max(C(r->i) + I(j->r)), i <= r < j cl = stripe(s_c, n, w, (0, 0), 0) + stripe(s_i, n, w, (w, 0)) cl_span, cl_path = cl.permute(2, 0, 1).max(-1) s_c.diagonal(-w).copy_(cl_span) p_c.diagonal(-w).copy_(cl_path + starts) # C(i->j) = max(I(i->r) + C(r->j)), i < r <= j cr = stripe(s_i, n, w, (0, 1)) + stripe(s_c, n, w, (1, w), 0) cr_span, cr_path = cr.permute(2, 0, 1).max(-1) s_c.diagonal(w).copy_(cr_span) # disable multi words to modify the root s_c[0, w][lens.ne(w)] = float('-inf') p_c.diagonal(w).copy_(cr_path + starts + 1) def backtrack(p_i, p_s, p_c, heads, i, j, flag): if i == j: return if flag == 'c': r = p_c[i, j] backtrack(p_i, p_s, p_c, heads, i, r, 'i') backtrack(p_i, p_s, p_c, heads, r, j, 'c') elif flag == 's': r = p_s[i, j] i, j = sorted((i, j)) backtrack(p_i, p_s, p_c, heads, i, r, 'c') backtrack(p_i, p_s, p_c, heads, j, r + 1, 'c') elif flag == 'i': r, heads[j] = p_i[i, j], i if r == i: r = i + 1 if i < j else i - 1 backtrack(p_i, p_s, p_c, heads, j, r, 'c') else: backtrack(p_i, p_s, p_c, heads, i, r, 'i') backtrack(p_i, p_s, p_c, heads, r, j, 's') preds = [] p_i = p_i.permute(2, 0, 1).cpu() p_s = p_s.permute(2, 0, 1).cpu() p_c = p_c.permute(2, 0, 1).cpu() 
for i, length in enumerate(lens.tolist()): heads = p_c.new_zeros(length + 1, dtype=torch.long) backtrack(p_i[i], p_s[i], p_c[i], heads, 0, length, 'c') preds.append(heads.to(mask.device)) return pad(preds, total_length=seq_len).to(mask.device)
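# NOTE: an illustrative sketch, not part of the source. The decoders above end
# with pad(preds, total_length=seq_len); a minimal version consistent with that
# call is given below. The real utility may handle more dimensions and options.
import torch

def pad(tensors, padding_value=0, total_length=None):
    # right-pad a list of 1-D tensors into a [batch_size, total_length] matrix
    size = max(len(t) for t in tensors) if total_length is None else total_length
    out = tensors[0].new_full((len(tensors), size), padding_value)
    for i, t in enumerate(tensors):
        out[i, :len(t)] = t
    return out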
def eisner(scores, mask): r""" First-order Eisner algorithm for projective decoding. References: - Ryan McDonald, Koby Crammer and Fernando Pereira. 2005. `Online Large-Margin Training of Dependency Parsers`_. Args: scores (~torch.Tensor): ``[batch_size, seq_len, seq_len]``. Scores of all dependent-head pairs. mask (~torch.BoolTensor): ``[batch_size, seq_len]``. The mask to avoid parsing over padding tokens. The first column serving as pseudo words for roots should be ``False``. Returns: ~torch.Tensor: A tensor with shape ``[batch_size, seq_len]`` for the resulting projective parse trees. Examples: >>> scores = torch.tensor([[[-13.5026, -18.3700, -13.0033, -16.6809], [-36.5235, -28.6344, -28.4696, -31.6750], [ -2.9084, -7.4825, -1.4861, -6.8709], [-29.4880, -27.6905, -26.1498, -27.0233]]]) >>> mask = torch.tensor([[False, True, True, True]]) >>> eisner(scores, mask) tensor([[0, 2, 0, 2]]) .. _Online Large-Margin Training of Dependency Parsers: https://www.aclweb.org/anthology/P05-1012/ """ lens = mask.sum(1) batch_size, seq_len, _ = scores.shape scores = scores.permute(2, 1, 0) s_i = torch.full_like(scores, float('-inf')) s_c = torch.full_like(scores, float('-inf')) p_i = scores.new_zeros(seq_len, seq_len, batch_size).long() p_c = scores.new_zeros(seq_len, seq_len, batch_size).long() s_c.diagonal().fill_(0) for w in range(1, seq_len): n = seq_len - w starts = p_i.new_tensor(range(n)).unsqueeze(0) # ilr = C(i->r) + C(j->r+1) ilr = stripe(s_c, n, w) + stripe(s_c, n, w, (w, 1)) # [batch_size, n, w] il = ir = ilr.permute(2, 0, 1) # I(j->i) = max(C(i->r) + C(j->r+1) + s(j->i)), i <= r < j il_span, il_path = il.max(-1) s_i.diagonal(-w).copy_(il_span + scores.diagonal(-w)) p_i.diagonal(-w).copy_(il_path + starts) # I(i->j) = max(C(i->r) + C(j->r+1) + s(i->j)), i <= r < j ir_span, ir_path = ir.max(-1) s_i.diagonal(w).copy_(ir_span + scores.diagonal(w)) p_i.diagonal(w).copy_(ir_path + starts) # C(j->i) = max(C(r->i) + I(j->r)), i <= r < j cl = stripe(s_c, n, w, (0, 0), 0) + stripe(s_i, n, w, (w, 0)) cl_span, cl_path = cl.permute(2, 0, 1).max(-1) s_c.diagonal(-w).copy_(cl_span) p_c.diagonal(-w).copy_(cl_path + starts) # C(i->j) = max(I(i->r) + C(r->j)), i < r <= j cr = stripe(s_i, n, w, (0, 1)) + stripe(s_c, n, w, (1, w), 0) cr_span, cr_path = cr.permute(2, 0, 1).max(-1) s_c.diagonal(w).copy_(cr_span) s_c[0, w][lens.ne(w)] = float('-inf') p_c.diagonal(w).copy_(cr_path + starts + 1) def backtrack(p_i, p_c, heads, i, j, complete): if i == j: return if complete: r = p_c[i, j] backtrack(p_i, p_c, heads, i, r, False) backtrack(p_i, p_c, heads, r, j, True) else: r, heads[j] = p_i[i, j], i i, j = sorted((i, j)) backtrack(p_i, p_c, heads, i, r, True) backtrack(p_i, p_c, heads, j, r + 1, True) preds = [] p_c = p_c.permute(2, 0, 1).cpu() p_i = p_i.permute(2, 0, 1).cpu() for i, length in enumerate(lens.tolist()): heads = p_c.new_zeros(length + 1, dtype=torch.long) backtrack(p_i[i], p_c[i], heads, 0, length, True) preds.append(heads.to(mask.device)) return pad(preds, total_length=seq_len).to(mask.device)
def forward(self, semiring):
    s_arc, s_sib = self.scores
    batch_size, seq_len = s_arc.shape[:2]
    # [seq_len, seq_len, batch_size, ...], (h->m)
    s_arc = semiring.convert(s_arc.movedim((1, 2), (1, 0)))
    # [seq_len, seq_len, seq_len, batch_size, ...], (h->m->s)
    s_sib = semiring.convert(s_sib.movedim((0, 2), (3, 0)))
    s_i = semiring.zeros_like(s_arc)
    s_s = semiring.zeros_like(s_arc)
    s_c = semiring.zeros_like(s_arc)
    semiring.one_(s_c.diagonal().movedim(-1, 1))

    for w in range(1, seq_len):
        n = seq_len - w

        # INCOMPLETE-L: I(j->i) = <I(j->r), S(j->r, i)> * s(j->i), i < r < j
        #                         <C(j->j), C(i->j-1)> * s(j->i), otherwise
        # [n, w, batch_size, ...]
        il = semiring.times(stripe(s_i, n, w, (w, 1)),
                            stripe(s_s, n, w, (1, 0), 0),
                            stripe(s_sib[range(w, n + w), range(n), :], n, w, (0, 1)))
        il[:, -1] = semiring.mul(stripe(s_c, n, 1, (w, w)),
                                 stripe(s_c, n, 1, (0, w - 1))).squeeze(1)
        il = semiring.sum(il, 1)
        s_i.diagonal(-w).copy_(
            semiring.mul(il, s_arc.diagonal(-w).movedim(-1, 0)).movedim(0, -1))
        # INCOMPLETE-R: I(i->j) = <I(i->r), S(i->r, j)> * s(i->j), i < r < j
        #                         <C(i->i), C(j->i+1)> * s(i->j), otherwise
        # [n, w, batch_size, ...]
        ir = semiring.times(stripe(s_i, n, w),
                            stripe(s_s, n, w, (0, w), 0),
                            stripe(s_sib[range(n), range(w, n + w), :], n, w))
        if not self.multiroot:
            semiring.zero_(ir[0])
        ir[:, 0] = semiring.mul(stripe(s_c, n, 1),
                                stripe(s_c, n, 1, (w, 1))).squeeze(1)
        ir = semiring.sum(ir, 1)
        s_i.diagonal(w).copy_(
            semiring.mul(ir, s_arc.diagonal(w).movedim(-1, 0)).movedim(0, -1))

        # [batch_size, ..., n]
        sl = sr = semiring.dot(stripe(s_c, n, w), stripe(s_c, n, w, (w, 1)), 1).movedim(0, -1)
        # SIB: S(j, i) = <C(i->r), C(j->r+1)>, i <= r < j
        s_s.diagonal(-w).copy_(sl)
        # SIB: S(i, j) = <C(i->r), C(j->r+1)>, i <= r < j
        s_s.diagonal(w).copy_(sr)

        # [n, batch_size, ...]
        # COMPLETE-L: C(j->i) = <C(r->i), I(j->r)>, i <= r < j
        cl = semiring.dot(stripe(s_c, n, w, (0, 0), 0), stripe(s_i, n, w, (w, 0)), 1)
        s_c.diagonal(-w).copy_(cl.movedim(0, -1))
        # COMPLETE-R: C(i->j) = <I(i->r), C(r->j)>, i < r <= j
        cr = semiring.dot(stripe(s_i, n, w, (0, 1)), stripe(s_c, n, w, (1, w), 0), 1)
        s_c.diagonal(w).copy_(cr.movedim(0, -1))
    return semiring.unconvert(s_c)[0][self.lens, range(batch_size)]
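# NOTE: an illustrative demo, not part of the source. Slices such as
# s_sib[range(w, n + w), range(n)] pair the two index lists elementwise:
# entry k is s_sib[w + k, k], i.e. one sibling-score fiber per span of width w.
import torch

x = torch.arange(3 * 3 * 2).view(3, 3, 2)
print(x[range(1, 3), range(2)].shape)                     # torch.Size([2, 2])
print(torch.equal(x[range(1, 3), range(2)][0], x[1, 0]))  # True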
def cky(scores, mask): r""" The implementation of `Cocke-Kasami-Younger`_ (CKY) algorithm to parse constituency trees :cite:`zhang-etal-2020-fast`. Args: scores (~torch.Tensor): ``[batch_size, seq_len, seq_len]``. Scores of all candidate constituents. mask (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``. The mask to avoid parsing over padding tokens. For each square matrix in a batch, the positions except upper triangular part should be masked out. Returns: Sequences of factorized predicted bracketed trees that are traversed in pre-order. Examples: >>> scores = torch.tensor([[[ 2.5659, 1.4253, -2.5272, 3.3011], [ 1.3687, -0.5869, 1.0011, 3.3020], [ 1.2297, 0.4862, 1.1975, 2.5387], [-0.0511, -1.2541, -0.7577, 0.2659]]]) >>> mask = torch.tensor([[[False, True, True, True], [False, False, True, True], [False, False, False, True], [False, False, False, False]]]) >>> cky(scores, mask) [[(0, 3), (0, 1), (1, 3), (1, 2), (2, 3)]] .. _Cocke-Kasami-Younger: https://en.wikipedia.org/wiki/CYK_algorithm """ lens = mask[:, 0].sum(-1) scores = scores.permute(1, 2, 3, 0) seq_len, seq_len, n_labels, batch_size = scores.shape s = scores.new_zeros(seq_len, seq_len, batch_size) p_s = scores.new_zeros(seq_len, seq_len, batch_size).long() p_l = scores.new_zeros(seq_len, seq_len, batch_size).long() for w in range(1, seq_len): n = seq_len - w starts = p_s.new_tensor(range(n)).unsqueeze(0) s_l, p = scores.diagonal(w).max(0) p_l.diagonal(w).copy_(p) if w == 1: s.diagonal(w).copy_(s_l) continue # [n, w, batch_size] s_s = stripe(s, n, w - 1, (0, 1)) + stripe(s, n, w - 1, (1, w), 0) # [batch_size, n, w] s_s = s_s.permute(2, 0, 1) # [batch_size, n] s_s, p = s_s.max(-1) s.diagonal(w).copy_(s_s + s_l) p_s.diagonal(w).copy_(p + starts + 1) def backtrack(p_s, p_l, i, j): if j == i + 1: return [(i, j, p_l[i][j])] split, label = p_s[i][j], p_l[i][j] ltree = backtrack(p_s, p_l, i, split) rtree = backtrack(p_s, p_l, split, j) return [(i, j, label)] + ltree + rtree p_s = p_s.permute(2, 0, 1).tolist() p_l = p_l.permute(2, 0, 1).tolist() trees = [ backtrack(p_s[i], p_l[i], 0, length) for i, length in enumerate(lens.tolist()) ] return trees
def inside(self, scores, mask, cands=None):
    # the end position of each sentence in a batch
    lens = mask.sum(1)
    s_arc, s_sib = scores
    batch_size, seq_len, _ = s_arc.shape
    # [seq_len, seq_len, batch_size]
    s_arc = s_arc.permute(2, 1, 0)
    # [seq_len, seq_len, seq_len, batch_size]
    s_sib = s_sib.permute(2, 1, 3, 0)
    s_i = torch.full_like(s_arc, float('-inf'))
    s_s = torch.full_like(s_arc, float('-inf'))
    s_c = torch.full_like(s_arc, float('-inf'))
    s_c.diagonal().fill_(0)
    # set the scores of arcs excluded by cands to -inf
    if cands is not None:
        mask = mask.index_fill(1, lens.new_tensor(0), 1)
        mask = (mask.unsqueeze(1) & mask.unsqueeze(-1)).permute(2, 1, 0)
        cands = cands.unsqueeze(-1).index_fill(1, lens.new_tensor(0), -1)
        cands = cands.eq(lens.new_tensor(range(seq_len))) | cands.lt(0)
        cands = cands.permute(2, 1, 0) & mask
        s_arc = s_arc.masked_fill(~cands, float('-inf'))

    for w in range(1, seq_len):
        # n denotes the number of spans to iterate,
        # from span (0, w) to span (n, n+w) given width w
        n = seq_len - w

        # I(j->i) = logsumexp(I(j->r) + S(j->r, i), i < r < j;
        #                     C(j->j) + C(i->j-1))
        #           + s(j->i)
        # [n, w, batch_size]
        il = stripe(s_i, n, w, (w, 1)) + stripe(s_s, n, w, (1, 0), 0)
        il += stripe(s_sib[range(w, n + w), range(n)], n, w, (0, 1))
        # [n, 1, batch_size]
        il0 = stripe(s_c, n, 1, (w, w)) + stripe(s_c, n, 1, (0, w - 1))
        # il0[0] is set to zero since complete spans starting from 0 always have score -inf
        il[:, -1] = il0.index_fill_(0, lens.new_tensor(0), 0).squeeze(1)
        if il.requires_grad:
            il.register_hook(lambda x: x.masked_fill_(torch.isnan(x), 0))
        il = il.permute(2, 0, 1).logsumexp(-1)
        s_i.diagonal(-w).copy_(il + s_arc.diagonal(-w))
        # I(i->j) = logsumexp(I(i->r) + S(i->r, j), i < r < j;
        #                     C(i->i) + C(j->i+1))
        #           + s(i->j)
        # [n, w, batch_size]
        ir = stripe(s_i, n, w) + stripe(s_s, n, w, (0, w), 0)
        ir += stripe(s_sib[range(n), range(w, n + w)], n, w)
        ir[0] = float('-inf')
        # [n, 1, batch_size]
        ir0 = stripe(s_c, n, 1) + stripe(s_c, n, 1, (w, 1))
        ir[:, 0] = ir0.squeeze(1)
        if ir.requires_grad:
            ir.register_hook(lambda x: x.masked_fill_(torch.isnan(x), 0))
        ir = ir.permute(2, 0, 1).logsumexp(-1)
        s_i.diagonal(w).copy_(ir + s_arc.diagonal(w))

        # [n, w, batch_size]
        slr = stripe(s_c, n, w) + stripe(s_c, n, w, (w, 1))
        if slr.requires_grad:
            slr.register_hook(lambda x: x.masked_fill_(torch.isnan(x), 0))
        slr = slr.permute(2, 0, 1).logsumexp(-1)
        # S(j, i) = logsumexp(C(i->r) + C(j->r+1)), i <= r < j
        s_s.diagonal(-w).copy_(slr)
        # S(i, j) = logsumexp(C(i->r) + C(j->r+1)), i <= r < j
        s_s.diagonal(w).copy_(slr)

        # C(j->i) = logsumexp(C(r->i) + I(j->r)), i <= r < j
        cl = stripe(s_c, n, w, (0, 0), 0) + stripe(s_i, n, w, (w, 0))
        cl.register_hook(lambda x: x.masked_fill_(torch.isnan(x), 0))
        s_c.diagonal(-w).copy_(cl.permute(2, 0, 1).logsumexp(-1))
        # C(i->j) = logsumexp(I(i->r) + C(r->j)), i < r <= j
        cr = stripe(s_i, n, w, (0, 1)) + stripe(s_c, n, w, (1, w), 0)
        cr.register_hook(lambda x: x.masked_fill_(torch.isnan(x), 0))
        s_c.diagonal(w).copy_(cr.permute(2, 0, 1).logsumexp(-1))
        # enforce the single-root constraint: the root's complete span must cover the whole sentence
        s_c[0, w][lens.ne(w)] = float('-inf')

    return s_c[0].gather(0, lens.unsqueeze(0)).sum()
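# NOTE: an illustrative demo, not part of the source. The inside routines above
# return (summed) log-partition values; back-propagating through them yields
# posterior marginals, just as the gradient of logsumexp is a softmax.
import torch

s = torch.randn(4, requires_grad=True)
s.logsumexp(-1).backward()
print(torch.allclose(s.grad, s.detach().softmax(-1)))  # True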