def mst(scores, mask, multiroot=False):
    """
    MST algorithm for decoding non-projective trees.
    This is a wrapper for the ChuLiu/Edmonds algorithm.

    The algorithm first runs ChuLiu/Edmonds to parse a tree and then checks for multiple roots.
    If ``multiroot=True`` and multiple roots indeed exist, the algorithm seeks the best single-root tree
    by iterating over all possible single-root trees parsed by ChuLiu/Edmonds.
    Otherwise the resulting trees are directly taken as the final outputs.

    Args:
        scores (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
            Scores of all dependent-head pairs.
        mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
            Mask to avoid parsing over padding tokens.
            The first column serving as pseudo words for roots should be ``False``.
        multiroot (bool):
            Ensures to parse a single-root tree if ``False``.

    Returns:
        ~torch.Tensor:
            A tensor with shape ``[batch_size, seq_len]`` for the resulting non-projective parse trees.

    Examples:
        >>> scores = torch.tensor([[[-11.9436, -13.1464,  -6.4789, -13.8917],
                                    [-60.6957, -60.2866, -48.6457, -63.8125],
                                    [-38.1747, -49.9296, -45.2733, -49.5571],
                                    [-19.7504, -23.9066,  -9.9139, -16.2088]]])
        >>> scores[:, 0, 1:] = float('-inf')
        >>> scores.diagonal(0, 1, 2)[1:].fill_(float('-inf'))
        >>> mask = torch.tensor([[False,  True,  True,  True]])
        >>> mst(scores, mask)
        tensor([[0, 2, 0, 2]])
    """

    batch_size, seq_len, _ = scores.shape
    scores = scores.cpu().unbind()

    preds = []
    for i, length in enumerate(mask.sum(1).tolist()):
        s = scores[i][:length+1, :length+1]
        tree = chuliu_edmonds(s)
        roots = torch.where(tree[1:].eq(0))[0] + 1
        if not multiroot and len(roots) > 1:
            s_root = s[:, 0]
            s_best = float('-inf')
            # enumerate all candidate roots and keep the highest-scoring single-root tree
            s = s.index_fill(1, torch.tensor(0), float('-inf'))
            for root in roots:
                s[:, 0] = float('-inf')
                s[root, 0] = s_root[root]
                t = chuliu_edmonds(s)
                s_tree = s[1:].gather(1, t[1:].unsqueeze(-1)).sum()
                if s_tree > s_best:
                    s_best, tree = s_tree, t
        preds.append(tree)

    return pad(preds, total_length=seq_len).to(mask.device)
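# Note: ``mst`` and the decoders below rely on two helpers that are not shown in this
# excerpt: ``chuliu_edmonds`` and ``pad``. As a hedged illustration only, the padding
# helper might look roughly like the sketch below; the signature
# ``pad(tensors, padding_value=0, total_length=None)`` is inferred from the call sites
# above and is an assumption, not the library's actual implementation.
def pad(tensors, padding_value=0, total_length=None):
    # stack a list of (possibly differently sized) tensors into one batch tensor,
    # padding every dimension up to the per-batch maximum
    size = [len(tensors)] + [max(t.size(i) for t in tensors) for i in range(tensors[0].dim())]
    if total_length is not None:
        # callers pass total_length=seq_len so all batches share the same width
        size[1] = max(size[1], total_length)
    out = tensors[0].new_full(size, padding_value)
    for i, tensor in enumerate(tensors):
        out[i][tuple(slice(0, d) for d in tensor.size())] = tensor
    return out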
def compose(self, sequences):
    """
    Composes a batch of sequences into a padded tensor.

    Args:
        sequences (list[~torch.Tensor]):
            A list of tensors.

    Returns:
        A padded tensor converted to the proper device.
    """

    return pad(sequences, self.pad_index).to(self.device)
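# A hedged usage sketch of ``compose``: the owning field object is not part of this
# excerpt, so a SimpleNamespace with assumed ``pad_index``/``device`` attributes stands
# in for ``self`` purely for illustration (the output assumes ``pad`` behaves as
# sketched above).
#
#     >>> import torch
#     >>> from types import SimpleNamespace
#     >>> field = SimpleNamespace(pad_index=0, device='cpu')
#     >>> compose(field, [torch.tensor([2, 5, 7]), torch.tensor([3, 4])])
#     tensor([[2, 5, 7],
#             [3, 4, 0]])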
def transform(self, sequences):
    sequences = [[self.preprocess(token) for token in seq] for seq in sequences]
    if self.fix_len <= 0:
        self.fix_len = max(len(token) for seq in sequences for token in seq)
    if self.use_vocab:
        sequences = [[[self.vocab[i] for i in token] if token else [self.unk] for token in seq]
                     for seq in sequences]
    if self.bos:
        sequences = [[[self.bos_index]] + seq for seq in sequences]
    if self.eos:
        sequences = [seq + [[self.eos_index]] for seq in sequences]
    lens = [min(self.fix_len, max(len(ids) for ids in seq)) for seq in sequences]
    sequences = [pad([torch.tensor(ids[:i]) for ids in seq], self.pad_index, i)
                 for i, seq in zip(lens, sequences)]

    return sequences
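# A hedged usage sketch of ``transform``: the attributes it reads (``preprocess``,
# ``fix_len``, ``use_vocab``, ``vocab``, ``unk``, ``bos``, ``eos``, ``pad_index``) are
# assumptions filled in with toy values here; the real field class is not part of this
# excerpt. Each token is split into sub-token (e.g. character) ids and padded per sentence.
#
#     >>> import torch
#     >>> from types import SimpleNamespace
#     >>> field = SimpleNamespace(preprocess=str.lower, fix_len=20,
#     ...                         use_vocab=True, vocab={'a': 1, 'b': 2, 'c': 3}, unk=0,
#     ...                         bos=None, eos=None, pad_index=0)
#     >>> transform(field, [['AB', 'C']])[0].tolist()
#     [[1, 2], [3, 0]]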
def eisner2o(scores, mask):
    """
    Second-order Eisner algorithm for projective decoding.
    This is an extension of the first-order one that further incorporates sibling scores into tree scoring.

    References:
        - Ryan McDonald and Fernando Pereira. 2006.
          `Online Learning of Approximate Dependency Parsing Algorithms`_.

    Args:
        scores (~torch.Tensor, ~torch.Tensor):
            A tuple of two tensors representing the first-order and second-order scores respectively.
            The first (``[batch_size, seq_len, seq_len]``) holds scores of all dependent-head pairs.
            The second (``[batch_size, seq_len, seq_len, seq_len]``) holds scores of all dependent-head-sibling triples.
        mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
            Mask to avoid parsing over padding tokens.
            The first column serving as pseudo words for roots should be ``False``.

    Returns:
        ~torch.Tensor:
            A tensor with shape ``[batch_size, seq_len]`` for the resulting projective parse trees.

    Examples:
        >>> s_arc = torch.tensor([[[ -2.8092,  -7.9104,  -0.9414,  -5.4360],
                                   [-10.3494,  -7.9298,  -3.6929,  -7.3985],
                                   [  1.1815,  -3.8291,   2.3166,  -2.7183],
                                   [ -3.9776,  -3.9063,  -1.6762,  -3.1861]]])
        >>> s_sib = torch.tensor([[[[ 0.4719,  0.4154,  1.1333,  0.6946],
                                    [ 1.1252,  1.3043,  2.1128,  1.4621],
                                    [ 0.5974,  0.5635,  1.0115,  0.7550],
                                    [ 1.1174,  1.3794,  2.2567,  1.4043]],
                                   [[-2.1480, -4.1830, -2.5519, -1.8020],
                                    [-1.2496, -1.7859, -0.0665, -0.4938],
                                    [-2.6171, -4.0142, -2.9428, -2.2121],
                                    [-0.5166, -1.0925,  0.5190,  0.1371]],
                                   [[ 0.5827, -1.2499, -0.0648, -0.0497],
                                    [ 1.4695,  0.3522,  1.5614,  1.0236],
                                    [ 0.4647, -0.7996, -0.3801,  0.0046],
                                    [ 1.5611,  0.3875,  1.8285,  1.0766]],
                                   [[-1.3053, -2.9423, -1.5779, -1.2142],
                                    [-0.1908, -0.9699,  0.3085,  0.1061],
                                    [-1.6783, -2.8199, -1.8853, -1.5653],
                                    [ 0.3629, -0.3488,  0.9011,  0.5674]]]])
        >>> mask = torch.tensor([[False,  True,  True,  True]])
        >>> eisner2o((s_arc, s_sib), mask)
        tensor([[0, 2, 0, 2]])

    .. _Online Learning of Approximate Dependency Parsing Algorithms:
        https://www.aclweb.org/anthology/E06-1011/
    """

    # the end position of each sentence in a batch
    lens = mask.sum(1)
    s_arc, s_sib = scores
    batch_size, seq_len, _ = s_arc.shape
    # [seq_len, seq_len, batch_size]
    s_arc = s_arc.permute(2, 1, 0)
    # [seq_len, seq_len, seq_len, batch_size]
    s_sib = s_sib.permute(2, 1, 3, 0)
    s_i = torch.full_like(s_arc, float('-inf'))
    s_s = torch.full_like(s_arc, float('-inf'))
    s_c = torch.full_like(s_arc, float('-inf'))
    p_i = s_arc.new_zeros(seq_len, seq_len, batch_size).long()
    p_s = s_arc.new_zeros(seq_len, seq_len, batch_size).long()
    p_c = s_arc.new_zeros(seq_len, seq_len, batch_size).long()
    s_c.diagonal().fill_(0)

    for w in range(1, seq_len):
        # n denotes the number of spans to iterate,
        # from span (0, w) to span (n, n+w) given width w
        n = seq_len - w
        starts = p_i.new_tensor(range(n)).unsqueeze(0)
        # I(j->i) = max(I(j->r) + S(j->r, i)), i < r < j |
        #               C(j->j) + C(i->j-1))
        #           + s(j->i)
        # [n, w, batch_size]
        il = stripe(s_i, n, w, (w, 1)) + stripe(s_s, n, w, (1, 0), 0)
        il += stripe(s_sib[range(w, n+w), range(n)], n, w, (0, 1))
        # [n, 1, batch_size]
        il0 = stripe(s_c, n, 1, (w, w)) + stripe(s_c, n, 1, (0, w - 1))
        # il0[0] are set to zeros since the scores of the complete spans starting from 0 are always -inf
        il[:, -1] = il0.index_fill_(0, lens.new_tensor(0), 0).squeeze(1)
        il_span, il_path = il.permute(2, 0, 1).max(-1)
        s_i.diagonal(-w).copy_(il_span + s_arc.diagonal(-w))
        p_i.diagonal(-w).copy_(il_path + starts + 1)
        # I(i->j) = max(I(i->r) + S(i->r, j), i < r < j |
        #               C(i->i) + C(j->i+1))
        #           + s(i->j)
        # [n, w, batch_size]
        ir = stripe(s_i, n, w) + stripe(s_s, n, w, (0, w), 0)
        ir += stripe(s_sib[range(n), range(w, n+w)], n, w)
        ir[0] = float('-inf')
        # [n, 1, batch_size]
        ir0 = stripe(s_c, n, 1) + stripe(s_c, n, 1, (w, 1))
        ir[:, 0] = ir0.squeeze(1)
        ir_span, ir_path = ir.permute(2, 0, 1).max(-1)
        s_i.diagonal(w).copy_(ir_span + s_arc.diagonal(w))
        p_i.diagonal(w).copy_(ir_path + starts)

        # [n, w, batch_size]
        slr = stripe(s_c, n, w) + stripe(s_c, n, w, (w, 1))
        slr_span, slr_path = slr.permute(2, 0, 1).max(-1)
        # S(j, i) = max(C(i->r) + C(j->r+1)), i <= r < j
        s_s.diagonal(-w).copy_(slr_span)
        p_s.diagonal(-w).copy_(slr_path + starts)
        # S(i, j) = max(C(i->r) + C(j->r+1)), i <= r < j
        s_s.diagonal(w).copy_(slr_span)
        p_s.diagonal(w).copy_(slr_path + starts)

        # C(j->i) = max(C(r->i) + I(j->r)), i <= r < j
        cl = stripe(s_c, n, w, (0, 0), 0) + stripe(s_i, n, w, (w, 0))
        cl_span, cl_path = cl.permute(2, 0, 1).max(-1)
        s_c.diagonal(-w).copy_(cl_span)
        p_c.diagonal(-w).copy_(cl_path + starts)
        # C(i->j) = max(I(i->r) + C(r->j)), i < r <= j
        cr = stripe(s_i, n, w, (0, 1)) + stripe(s_c, n, w, (1, w), 0)
        cr_span, cr_path = cr.permute(2, 0, 1).max(-1)
        s_c.diagonal(w).copy_(cr_span)
        # disable multiple words from modifying the root
        s_c[0, w][lens.ne(w)] = float('-inf')
        p_c.diagonal(w).copy_(cr_path + starts + 1)

    def backtrack(p_i, p_s, p_c, heads, i, j, flag):
        if i == j:
            return
        if flag == 'c':
            r = p_c[i, j]
            backtrack(p_i, p_s, p_c, heads, i, r, 'i')
            backtrack(p_i, p_s, p_c, heads, r, j, 'c')
        elif flag == 's':
            r = p_s[i, j]
            i, j = sorted((i, j))
            backtrack(p_i, p_s, p_c, heads, i, r, 'c')
            backtrack(p_i, p_s, p_c, heads, j, r + 1, 'c')
        elif flag == 'i':
            r, heads[j] = p_i[i, j], i
            if r == i:
                r = i + 1 if i < j else i - 1
                backtrack(p_i, p_s, p_c, heads, j, r, 'c')
            else:
                backtrack(p_i, p_s, p_c, heads, i, r, 'i')
                backtrack(p_i, p_s, p_c, heads, r, j, 's')

    preds = []
    p_i = p_i.permute(2, 0, 1).cpu()
    p_s = p_s.permute(2, 0, 1).cpu()
    p_c = p_c.permute(2, 0, 1).cpu()
    for i, length in enumerate(lens.tolist()):
        heads = p_c.new_zeros(length + 1, dtype=torch.long)
        backtrack(p_i[i], p_s[i], p_c[i], heads, 0, length, 'c')
        preds.append(heads.to(mask.device))

    return pad(preds, total_length=seq_len).to(mask.device)
def eisner(scores, mask):
    """
    First-order Eisner algorithm for projective decoding.

    References:
        - Ryan McDonald, Koby Crammer and Fernando Pereira. 2005.
          `Online Large-Margin Training of Dependency Parsers`_.

    Args:
        scores (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
            Scores of all dependent-head pairs.
        mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
            Mask to avoid parsing over padding tokens.
            The first column serving as pseudo words for roots should be ``False``.

    Returns:
        ~torch.Tensor:
            A tensor with shape ``[batch_size, seq_len]`` for the resulting projective parse trees.

    Examples:
        >>> scores = torch.tensor([[[-13.5026, -18.3700, -13.0033, -16.6809],
                                    [-36.5235, -28.6344, -28.4696, -31.6750],
                                    [ -2.9084,  -7.4825,  -1.4861,  -6.8709],
                                    [-29.4880, -27.6905, -26.1498, -27.0233]]])
        >>> mask = torch.tensor([[False,  True,  True,  True]])
        >>> eisner(scores, mask)
        tensor([[0, 2, 0, 2]])

    .. _Online Large-Margin Training of Dependency Parsers:
        https://www.aclweb.org/anthology/P05-1012/
    """

    lens = mask.sum(1)
    batch_size, seq_len, _ = scores.shape
    scores = scores.permute(2, 1, 0)
    s_i = torch.full_like(scores, float('-inf'))
    s_c = torch.full_like(scores, float('-inf'))
    p_i = scores.new_zeros(seq_len, seq_len, batch_size).long()
    p_c = scores.new_zeros(seq_len, seq_len, batch_size).long()
    s_c.diagonal().fill_(0)

    for w in range(1, seq_len):
        n = seq_len - w
        starts = p_i.new_tensor(range(n)).unsqueeze(0)
        # ilr = C(i->r) + C(j->r+1)
        ilr = stripe(s_c, n, w) + stripe(s_c, n, w, (w, 1))
        # [batch_size, n, w]
        il = ir = ilr.permute(2, 0, 1)
        # I(j->i) = max(C(i->r) + C(j->r+1) + s(j->i)), i <= r < j
        il_span, il_path = il.max(-1)
        s_i.diagonal(-w).copy_(il_span + scores.diagonal(-w))
        p_i.diagonal(-w).copy_(il_path + starts)
        # I(i->j) = max(C(i->r) + C(j->r+1) + s(i->j)), i <= r < j
        ir_span, ir_path = ir.max(-1)
        s_i.diagonal(w).copy_(ir_span + scores.diagonal(w))
        p_i.diagonal(w).copy_(ir_path + starts)

        # C(j->i) = max(C(r->i) + I(j->r)), i <= r < j
        cl = stripe(s_c, n, w, (0, 0), 0) + stripe(s_i, n, w, (w, 0))
        cl_span, cl_path = cl.permute(2, 0, 1).max(-1)
        s_c.diagonal(-w).copy_(cl_span)
        p_c.diagonal(-w).copy_(cl_path + starts)
        # C(i->j) = max(I(i->r) + C(r->j)), i < r <= j
        cr = stripe(s_i, n, w, (0, 1)) + stripe(s_c, n, w, (1, w), 0)
        cr_span, cr_path = cr.permute(2, 0, 1).max(-1)
        s_c.diagonal(w).copy_(cr_span)
        s_c[0, w][lens.ne(w)] = float('-inf')
        p_c.diagonal(w).copy_(cr_path + starts + 1)

    def backtrack(p_i, p_c, heads, i, j, complete):
        if i == j:
            return
        if complete:
            r = p_c[i, j]
            backtrack(p_i, p_c, heads, i, r, False)
            backtrack(p_i, p_c, heads, r, j, True)
        else:
            r, heads[j] = p_i[i, j], i
            i, j = sorted((i, j))
            backtrack(p_i, p_c, heads, i, r, True)
            backtrack(p_i, p_c, heads, j, r + 1, True)

    preds = []
    p_c = p_c.permute(2, 0, 1).cpu()
    p_i = p_i.permute(2, 0, 1).cpu()
    for i, length in enumerate(lens.tolist()):
        heads = p_c.new_zeros(length + 1, dtype=torch.long)
        backtrack(p_i[i], p_c[i], heads, 0, length, True)
        preds.append(heads.to(mask.device))

    return pad(preds, total_length=seq_len).to(mask.device)
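# Both ``eisner`` and ``eisner2o`` above read diagonal bands of the chart tensors
# through a ``stripe`` helper that is not included in this excerpt. The sketch below is
# a hedged illustration of one way such a helper can be written with ``as_strided``;
# the signature ``stripe(x, n, w, offset=(0, 0), dim=1)`` is inferred from the call
# sites and should be treated as an assumption rather than the library's own code.
def stripe(x, n, w, offset=(0, 0), dim=1):
    # x: [seq_len, seq_len, ...]; returns a view of n spans, each of width w, taken
    # along a diagonal starting at `offset`; dim=1 slides horizontally, dim=0 vertically
    x, seq_len = x.contiguous(), x.size(1)
    stride, numel = list(x.stride()), x[0, 0].numel()
    stride[0] = (seq_len + 1) * numel
    stride[1] = (1 if dim == 1 else seq_len) * numel
    return x.as_strided(size=(n, w, *x.shape[2:]),
                        stride=stride,
                        storage_offset=(offset[0] * seq_len + offset[1]) * numel)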
def compose(self, sequences):
    return [pad(i).to(self.device) for i in zip(*sequences)]
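# A hedged usage sketch: ``sequences`` is assumed to hold one entry per sentence, each a
# tuple with one tensor per field, so ``zip(*sequences)`` regroups the tensors by field
# before padding (the default padding value of ``pad`` is assumed to be 0).
#
#     >>> import torch
#     >>> from types import SimpleNamespace
#     >>> batch = compose(SimpleNamespace(device='cpu'),
#     ...                 [(torch.tensor([4, 9]), torch.tensor([1, 2])),
#     ...                  (torch.tensor([7]), torch.tensor([3]))])
#     >>> [t.tolist() for t in batch]
#     [[[4, 9], [7, 0]], [[1, 2], [3, 0]]]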