def decode(self, s_arc, s_rel, mask, tree=False, proj=False):
    r"""
    Decodes the best arcs and labels from first-order arc/label scores.

    Args:
        s_arc (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
            Scores of all possible arcs.
        s_rel (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
            Scores of all possible labels on each arc.
        mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
            The mask for covering the unpadded tokens.
        tree (bool):
            If ``True``, ensures to output well-formed trees. Default: ``False``.
        proj (bool):
            If ``True``, ensures to output projective trees. Default: ``False``.

    Returns:
        ~torch.LongTensor, ~torch.LongTensor:
            Predicted arcs and labels of shape ``[batch_size, seq_len]``.
    """

    # greedy head selection: each token takes its highest-scoring head
    arc_preds = s_arc.argmax(-1)
    seq_lens = mask.sum(1).tolist()
    # flag sentences whose greedy heads do not form a (projective) tree
    bad = [not CoNLL.istree(heads[1:n + 1], proj)
           for n, heads in zip(seq_lens, arc_preds.tolist())]
    if tree and any(bad):
        # re-decode only the ill-formed sentences with an exact algorithm
        decoder = eisner if proj else mst
        arc_preds[bad] = decoder(s_arc[bad], mask[bad])
    # pick the best label along each predicted head arc
    rel_preds = s_rel.argmax(-1).gather(-1, arc_preds.unsqueeze(-1)).squeeze(-1)
    return arc_preds, rel_preds
def enumerate(self, semiring):
    """Exhaustively scores every valid dependency tree of each sentence.

    For a sentence of length ``n``, iterates over all head sequences in
    ``{0..n}^n``, keeps those that form a valid tree, and accumulates the
    product (under ``semiring``) of the chosen arc scores.

    Returns:
        A list with one stacked score tensor per sentence, each entry
        covering all of that sentence's valid trees.
    """
    results = []
    for i, n in enumerate(self.lens.tolist()):
        tree_scores = []
        # candidate head assignment for tokens 1..n; 0 denotes the root
        for heads in itertools.product(range(n + 1), repeat=n):
            if CoNLL.istree(list(heads), True, self.multiroot):
                arc_scores = self.scores[i, range(1, n + 1), heads]
                tree_scores.append(semiring.prod(arc_scores, -1))
        results.append(tree_scores)
    return [torch.stack(scores) for scores in results]
def decode(self, s_arc, s_sib, s_rel, mask, tree=False, mbr=True, proj=False):
    """
    Args:
        s_arc (torch.Tensor): [batch_size, seq_len, seq_len]
            The scores of all possible arcs.
        s_sib (torch.Tensor): [batch_size, seq_len, seq_len, seq_len]
            The scores of all possible dependent-head-sibling triples.
        s_rel (torch.Tensor): [batch_size, seq_len, seq_len, n_labels]
            The scores of all possible labels on each arc.
        mask (torch.BoolTensor): [batch_size, seq_len]
            Mask for covering the unpadded tokens.
        tree (bool):
            If True, ensures to output well-formed trees. Default: False.
        mbr (bool):
            If True, performs MBR decoding. Default: True.
        proj (bool):
            If True, ensures to output projective trees. Default: False.

    Returns:
        arc_preds (torch.Tensor): [batch_size, seq_len]
            The predicted arcs.
        rel_preds (torch.Tensor): [batch_size, seq_len]
            The predicted labels.
    """
    lens = mask.sum(1)
    # prevent self-loops; NOTE: mutates s_arc in place
    s_arc.diagonal(0, 1, 2).fill_(float('-inf'))
    arc_preds = s_arc.argmax(-1)
    # flag sentences whose greedy heads do not form a (projective) tree
    bad = [
        not CoNLL.istree(seq[1:i + 1], proj)
        for i, seq in zip(lens.tolist(), arc_preds.tolist())
    ]
    if tree and any(bad):
        if proj and not mbr:
            # second-order exact decoding; restricted to the ill-formed
            # sentences only, so well-formed greedy predictions are kept
            # (consistent with the first-order branch below)
            arc_preds[bad] = eisner2o((s_arc[bad], s_sib[bad]), mask[bad])
        else:
            alg = eisner if proj else mst
            arc_preds[bad] = alg(s_arc[bad], mask[bad])
    # pick the best label along each predicted head arc
    rel_preds = s_rel.argmax(-1).gather(
        -1, arc_preds.unsqueeze(-1)).squeeze(-1)
    return arc_preds, rel_preds
def enumerate(self, semiring):
    """Exhaustively scores every valid dependency tree with second-order
    (dependent-head-sibling) factors.

    For each sentence, iterates over all head sequences, keeps the valid
    trees, and combines (via ``semiring.mul``) the product of arc scores
    with the product of sibling scores along the tree.

    Returns:
        A list with one stacked score tensor per sentence, each entry
        covering all of that sentence's valid trees.
    """
    trees = []
    for i, length in enumerate(self.lens.tolist()):
        trees.append([])
        # candidate head assignment for tokens 1..length; 0 is the root
        for seq in itertools.product(range(length + 1), repeat=length):
            if not CoNLL.istree(list(seq), True, self.multiroot):
                continue
            # sibling index for each (dependent, head) pair; 0 means no sibling
            sibs = self.lens.new_tensor(CoNLL.get_sibs(seq))
            sib_mask = sibs.gt(0)
            # slice this sentence's scores; scores[0] are arc scores,
            # scores[1] are dependent-head-sibling scores
            s_arc = self.scores[0][i, :length+1, :length+1]
            s_sib = self.scores[1][i, :length+1, :length+1, :length+1]
            # product of the chosen arc scores over tokens 1..length
            s_arc = semiring.prod(s_arc[range(1, length + 1), seq], -1)
            # gather the sibling score for each pair that has one, then reduce
            s_sib = semiring.prod(s_sib[1:][sib_mask].gather(-1, sibs[sib_mask].unsqueeze(-1)).squeeze(-1))
            trees[-1].append(semiring.mul(s_arc, s_sib))
    return [torch.stack(seq) for seq in trees]
def decode(self, s_arc, s_sib, s_rel, mask, tree=False, mbr=True, proj=False):
    r"""
    Decodes arcs and labels, optionally repairing ill-formed predictions
    with exact structured decoding.

    Args:
        s_arc (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
            Scores of all possible arcs.
        s_sib (~torch.Tensor): ``[batch_size, seq_len, seq_len, seq_len]``.
            Scores of all possible dependent-head-sibling triples.
        s_rel (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
            Scores of all possible labels on each arc.
        mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
            The mask for covering the unpadded tokens.
        tree (bool):
            If ``True``, ensures to output well-formed trees. Default: ``False``.
        mbr (bool):
            If ``True``, performs MBR decoding. Default: ``True``.
        proj (bool):
            If ``True``, ensures to output projective trees. Default: ``False``.

    Returns:
        ~torch.LongTensor, ~torch.LongTensor:
            Predicted arcs and labels of shape ``[batch_size, seq_len]``.
    """

    # greedy head selection: each token takes its highest-scoring head
    arc_preds = s_arc.argmax(-1)
    seq_lens = mask.sum(1).tolist()
    # flag sentences whose greedy heads do not form a (projective) tree
    bad = [not CoNLL.istree(heads[1:n + 1], proj)
           for n, heads in zip(seq_lens, arc_preds.tolist())]
    if tree and any(bad):
        # re-decode only the ill-formed sentences via structured distributions
        sub_lens = mask[bad].sum(-1)
        if proj and not mbr:
            arc_preds[bad] = Dependency2oCRF(
                (s_arc[bad], s_sib[bad]), sub_lens).argmax
        else:
            struct = DependencyCRF if proj else MatrixTree
            arc_preds[bad] = struct(s_arc[bad], sub_lens).argmax
    # pick the best label along each predicted head arc
    rel_preds = s_rel.argmax(-1).gather(-1, arc_preds.unsqueeze(-1)).squeeze(-1)
    return arc_preds, rel_preds