# The snippets below assume PyTorch plus the repo's own `eisner` projective
# decoder and `get_graph_entities` helper are in scope.
import torch
import pytorch_lightning as pl


def step(self: Model, batch, batch_nb):
    result = self.forward(**batch)
    loss = result.loss
    parc = result.arc_logits
    prel = result.rel_logits
    mask: torch.Tensor = batch['word_attention_mask']

    # The pseudo ROOT may not attach to any real word, and no token may
    # attach to itself.
    parc[:, 0, 1:] = float('-inf')
    parc.diagonal(0, 1, 2).fill_(float('-inf'))
    # Prepend a False column for the ROOT slot so the mask matches the
    # decoder's sequence length.
    parc = eisner(
        parc,
        torch.cat([torch.zeros_like(mask[:, :1], dtype=torch.bool), mask],
                  dim=1))

    # Best label per (dependent, head) pair, then pick each token's label at
    # its predicted head.
    prel = torch.argmax(prel, dim=-1)
    prel = prel.gather(-1, parc.unsqueeze(-1)).squeeze(-1)

    arc_true = (parc[:, 1:] == batch['head'])[mask]
    rel_true = (prel[:, 1:] == batch['labels'])[mask]
    union_true = arc_true & rel_true

    return {
        loss_tag: loss.item(),
        f"{metric_tag}/arc_true": torch.sum(arc_true, dtype=torch.float).item(),
        f"{metric_tag}/uni_true": torch.sum(union_true, dtype=torch.float).item(),
        f"{metric_tag}/all": union_true.numel()
    }
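# A minimal sketch (not from the repo) of how the per-batch counts returned
# by `step` could be reduced into epoch-level UAS/LAS under the usual
# PyTorch Lightning convention; `loss_tag` and `metric_tag` are assumed to
# be module-level string constants, as used above.
def validation_epoch_end(self, outputs):
    total = sum(o[f"{metric_tag}/all"] for o in outputs)
    uas = sum(o[f"{metric_tag}/arc_true"] for o in outputs) / total
    las = sum(o[f"{metric_tag}/uni_true"] for o in outputs) / total
    loss = sum(o[loss_tag] for o in outputs) / len(outputs)
    self.log_dict({'val_loss': loss, 'val_uas': uas, 'val_las': las})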
def step(self, batch, preds, targets, mask) -> dict:
    parc, prel = preds
    rarc, rrel = targets

    parc[:, 0, 1:] = float('-inf')
    parc.diagonal(0, 1, 2).fill_(float('-inf'))
    parc = eisner(
        parc,
        torch.cat([torch.zeros_like(mask[:, :1], dtype=torch.bool), mask],
                  dim=1))

    prel = torch.argmax(prel, dim=-1)
    prel = prel.gather(-1, parc.unsqueeze(-1)).squeeze(-1)

    # Punctuation is excluded from scoring.
    punc_mask = rrel != self.punct_idx
    mask = mask & punc_mask

    arc_true = (parc == rarc)[mask]
    rel_true = (prel == rrel)[mask]
    union_true = arc_true & rel_true

    return {
        "arc_true": torch.sum(arc_true, dtype=torch.float).item(),
        "uni_true": torch.sum(union_true, dtype=torch.float).item(),
        "all": union_true.numel()
    }
def sdp(self, hidden: dict, graph=True):
    """
    Semantic dependency graph (or tree)

    Args:
        hidden: intermediate representations produced by the segmentation step
        graph: return a semantic dependency graph if True, a tree otherwise

    Returns:
        the semantic dependency graph (or tree)
    """
    word_attention_mask = hidden['word_cls_mask']
    sdp_arc, sdp_label = self.model.sdp_classifier(
        input=hidden['word_cls_input'],
        word_attention_mask=word_attention_mask[:, 1:])[0]
    sdp_arc[:, 0, 1:] = float('-inf')
    # Forbid self-loops (ROOT keeps its self-arc).
    sdp_arc.diagonal(0, 1, 2)[1:].fill_(float('-inf'))
    sdp_label = torch.argmax(sdp_label, dim=-1)

    if graph:
        # Semantic dependency graph: keep every arc scoring above 0.5.
        sdp_arc = torch.sigmoid_(sdp_arc) > 0.5
    else:
        # Semantic dependency tree: keep only the arcs chosen by Eisner
        # decoding.
        sdp_arc_idx = eisner(
            sdp_arc, word_attention_mask).unsqueeze_(-1).expand_as(sdp_arc)
        sdp_arc = torch.zeros_like(sdp_arc, dtype=torch.bool).scatter_(
            -1, sdp_arc_idx, True)

    sdp_arc[~word_attention_mask] = False
    sdp_label = get_graph_entities(sdp_arc, sdp_label, self.sdp_vocab)

    return sdp_label
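# `get_graph_entities` lives elsewhere in the repo; purely as an assumption
# about its contract, a minimal sketch: `arcs` is a bool tensor of shape
# [batch, len, len] (dependent x head), `labels` holds the argmax label ids
# with the same shape, and `vocab` maps ids to label strings.
def get_graph_entities_sketch(arcs, labels, vocab):
    results = []
    for b_arcs, b_labels in zip(arcs, labels):
        # nonzero() yields one (dependent, head) index pair per kept arc.
        results.append([(dep, head, vocab[b_labels[dep, head].item()])
                        for dep, head in b_arcs.nonzero().tolist()])
    return results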
def dep(self, hidden: dict, fast=True):
    """
    Dependency parse tree

    Args:
        hidden: intermediate representations produced by the segmentation step
        fast: fast mode applies fewer constraints to the output, trading some
            accuracy for speed

    Returns:
        the dependency parse tree
    """
    word_attention_mask = hidden['word_cls_mask']
    dep_arc, dep_label = self.model.dep_classifier.forward(
        input=hidden['word_cls_input'],
        word_attention_mask=word_attention_mask[:, 1:])[0]
    dep_arc[:, 0, 1:] = float('-inf')
    dep_arc.diagonal(0, 1, 2)[1:].fill_(float('-inf'))
    # Fast mode takes the per-token argmax head; otherwise Eisner decoding
    # guarantees a projective tree.
    dep_arc = dep_arc.argmax(
        dim=-1) if fast else eisner(dep_arc, word_attention_mask)

    dep_label = torch.argmax(dep_label, dim=-1)
    dep_label = dep_label.gather(-1, dep_arc.unsqueeze(-1)).squeeze(-1)

    dep_arc[~word_attention_mask] = -1
    dep_label[~word_attention_mask] = -1

    arc_pred = [[item for item in arcs if item != -1]
                for arcs in dep_arc[:, 1:].cpu().numpy().tolist()]
    rel_pred = [[self.dep_vocab[item] for item in rels if item != -1]
                for rels in dep_label[:, 1:].cpu().numpy().tolist()]

    return [[(idx + 1, arc, rel)
             for idx, (arc, rel) in enumerate(zip(arcs, rels))]
            for arcs, rels in zip(arc_pred, rel_pred)]
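# Usage sketch: the (idx, head, rel) triples returned by `dep` map directly
# onto CoNLL-style rows; `words` and `parses` are hypothetical placeholders
# for one batch's tokens and the parses returned above.
for (idx, head, rel), word in zip(parses[0], words[0]):
    # idx is 1-based; head 0 denotes the virtual ROOT node.
    print(f"{idx}\t{word}\t{head}\t{rel}")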
def step(self: pl.LightningModule, batch, batch_nb):
    loss, (parc, prel) = self(**batch)
    mask: torch.Tensor = batch['word_attention_mask']

    parc[:, 0, 1:] = float('-inf')
    parc.diagonal(0, 1, 2)[1:].fill_(float('-inf'))
    parc = eisner(
        parc,
        torch.cat([torch.zeros_like(mask[:, :1], dtype=torch.bool), mask],
                  dim=1))

    prel = torch.argmax(prel, dim=-1)
    prel = prel.gather(-1, parc.unsqueeze(-1)).squeeze(-1)

    arc_true = (parc[:, 1:] == batch['head'])[mask]
    rel_true = (prel[:, 1:] == batch['labels'])[mask]
    union_true = arc_true & rel_true

    return {
        loss_tag: loss.item(),
        f"{metric_tag}/true": torch.sum(union_true, dtype=torch.float).item(),
        f"{metric_tag}/all": union_true.numel()
    }
def sdp(self, hidden: dict, mode: str = 'graph'):
    """
    Semantic dependency graph (or tree)

    Args:
        hidden: intermediate representations produced by the segmentation step
        mode: one of ['tree', 'graph', 'mix']

    Returns:
        the semantic dependency graph (or tree)
    """
    if len(self.sdp_vocab) == 0:
        return []
    word_attention_mask = hidden['word_cls_mask']
    result = self.model.sdp_classifier(
        input=hidden['word_cls_input'],
        word_attention_mask=word_attention_mask[:, 1:],
        is_processed=True)
    sdp_arc, sdp_label = result.arc_logits, result.rel_logits
    sdp_arc[:, 0, 1:] = float('-inf')
    sdp_arc.diagonal(0, 1, 2).fill_(float('-inf'))  # forbid self-loops
    sdp_label = torch.argmax(sdp_label, dim=-1)

    if mode == 'tree':
        # Semantic dependency tree: keep only the arcs chosen by Eisner
        # decoding.
        sdp_arc_idx = eisner(
            sdp_arc, word_attention_mask).unsqueeze_(-1).expand_as(sdp_arc)
        sdp_arc_res = torch.zeros_like(sdp_arc, dtype=torch.bool).scatter_(
            -1, sdp_arc_idx, True)
    elif mode == 'mix':
        # Mixed decoding: the Eisner tree plus every arc scoring above 0.5.
        sdp_arc_idx = eisner(
            sdp_arc, word_attention_mask).unsqueeze_(-1).expand_as(sdp_arc)
        sdp_arc_res = (sdp_arc.sigmoid_() > 0.5).scatter_(
            -1, sdp_arc_idx, True)
    else:
        # Semantic dependency graph: keep every arc scoring above 0.5.
        sdp_arc_res = torch.sigmoid_(sdp_arc) > 0.5

    sdp_arc_res[~word_attention_mask] = False
    sdp_label = get_graph_entities(sdp_arc_res, sdp_label, self.sdp_vocab)

    return sdp_label
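# Toy illustration (not from the repo) of the 'mix' branch: scatter forces
# the Eisner tree's arcs to True on top of the sigmoid-thresholded graph, so
# the result is the union of both arc sets.
scores = torch.tensor([[[0.9, 0.1, 0.2],
                        [0.8, 0.3, 0.1],
                        [0.1, 0.7, 0.2]]])      # pseudo arc probabilities
graph = scores > 0.5                            # graph-mode arcs
tree_idx = torch.tensor([[0, 0, 1]])            # pretend Eisner heads
mix = graph.scatter(-1, tree_idx.unsqueeze(-1), True)
assert ((mix & graph) == graph).all()           # tree arcs only ever add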
def step(self, batch, preds, targets, mask) -> dict:
    seq_lens = torch.sum(mask, dim=-1, dtype=torch.int).cpu().numpy()
    parc, prel = preds
    rarc, rrel = targets

    mask = torch.cat(
        [torch.zeros_like(mask[:, :1], dtype=torch.bool), mask], dim=1)
    parc[:, 0, 1:] = float('-inf')
    parc.diagonal(0, 1, 2).fill_(float('-inf'))
    parc = eisner(parc, mask)

    # batch_size x max_len, without root
    arange_index = torch.arange(1, mask.shape[1] + 1, dtype=torch.long,
                                device=prel.device) \
        .unsqueeze(0) \
        .repeat(mask.shape[0], 1)
    # Restrict the first relation label (word-internal attachment) to tokens
    # whose predicted head is the immediately following token; mask it out
    # everywhere else.
    app_masks = parc.ne(arange_index)
    app_masks = app_masks.unsqueeze(-1).unsqueeze(-1).repeat(
        1, 1, 1, prel.shape[-1])
    app_masks[:, :, :, 1:] = 0
    prel = prel.masked_fill(app_masks, float("-inf"))

    prel = torch.argmax(prel, dim=-1)
    prel = prel.gather(-1, parc.unsqueeze(-1)).squeeze(-1)

    parc = parc[:, 1:]
    prel = prel[:, 1:]

    # Punctuation is excluded from scoring.
    punc_mask = rrel == self.punct_idx
    rrel = rrel.masked_fill(punc_mask, -1)
    prel = prel.masked_fill(punc_mask, -1)

    rarc_entities, rrel_entities, rseg_entities = self.get_graph_entities(
        rarc, rrel, seq_lens)
    parc_entities, prel_entities, pseg_entities = self.get_graph_entities(
        parc, prel, seq_lens)

    return {
        'uas_correct': len(parc_entities & rarc_entities),
        'uas_pred': len(parc_entities),
        'uas_true': len(rarc_entities),
        'las_correct': len(prel_entities & rrel_entities),
        'las_pred': len(prel_entities),
        'las_true': len(rrel_entities),
        'seg_correct': len(pseg_entities & rseg_entities),
        'seg_pred': len(pseg_entities),
        'seg_true': len(rseg_entities),
    }
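# Sketch (names hypothetical) of turning the summed correct/pred/true counts
# from `step` into precision/recall/F1, e.g. for LAS in this joint
# segmentation setting:
def f1_score(correct, pred, true, eps=1e-12):
    precision = correct / (pred + eps)
    recall = correct / (true + eps)
    return 2 * precision * recall / (precision + recall + eps)

# las_f1 = f1_score(sum_las_correct, sum_las_pred, sum_las_true)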
def dep(self, hidden: dict, fast=True, as_tuple=True):
    """
    Dependency parse tree

    Args:
        hidden: intermediate representations produced by the segmentation step
        fast: fast mode applies fewer constraints to the output, trading some
            accuracy for speed
        as_tuple: if True, return (idx, head, rel) triples; otherwise return
            heads, rels

    Returns:
        the dependency parse tree
    """
    if len(self.dep_vocab) == 0:
        return []
    word_attention_mask = hidden['word_cls_mask']
    result = self.model.dep_classifier.forward(
        input=hidden['word_cls_input'],
        word_attention_mask=word_attention_mask[:, 1:],
        is_processed=True)
    dep_arc, dep_label = result.arc_logits, result.rel_logits
    dep_arc[:, 0, 1:] = float('-inf')
    dep_arc.diagonal(0, 1, 2).fill_(float('-inf'))
    dep_arc = dep_arc.argmax(
        dim=-1) if fast else eisner(dep_arc, word_attention_mask)

    dep_label = torch.argmax(dep_label, dim=-1)
    dep_label = dep_label.gather(-1, dep_arc.unsqueeze(-1)).squeeze(-1)

    dep_arc[~word_attention_mask] = -1
    dep_label[~word_attention_mask] = -1

    head_pred = [[item for item in arcs if item != -1]
                 for arcs in dep_arc[:, 1:].cpu().numpy().tolist()]
    rel_pred = [[self.dep_vocab[item] for item in rels if item != -1]
                for rels in dep_label[:, 1:].cpu().numpy().tolist()]

    if not as_tuple:
        return head_pred, rel_pred
    return [[(idx + 1, head, rel)
             for idx, (head, rel) in enumerate(zip(heads, rels))]
            for heads, rels in zip(head_pred, rel_pred)]
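# Hypothetical end-to-end usage, assuming the surrounding class also exposes
# the segmentation step that produces `hidden` (names illustrative only):
#
#     seg, hidden = ltp.seg(["他叫汤姆去拿外衣。"])
#     print(ltp.dep(hidden))                # e.g. [[(1, 2, 'SBV'), ...]]
#     print(ltp.sdp(hidden, mode='graph'))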