Example 1
    def step(self: Model, batch, batch_nb):
        result = self.forward(**batch)
        loss = result.loss
        parc = result.arc_logits
        prel = result.rel_logits

        mask: torch.Tensor = batch['word_attention_mask']

        # mask illegal arcs: the pseudo-root keeps itself as head, and no token may head itself
        parc[:, 0, 1:] = float('-inf')
        parc.diagonal(0, 1, 2).fill_(float('-inf'))
        # prepend a (masked-out) root column so the mask matches the root-augmented
        # scores, then decode a projective tree with Eisner's algorithm
        parc = eisner(
            parc,
            torch.cat([torch.zeros_like(mask[:, :1], dtype=torch.bool), mask],
                      dim=1))

        # best relation for every (dependent, head) pair, then keep the one for the predicted head
        prel = torch.argmax(prel, dim=-1)
        prel = prel.gather(-1, parc.unsqueeze(-1)).squeeze(-1)

        arc_true = (parc[:, 1:] == batch['head'])[mask]    # correct heads (UAS numerator)
        rel_true = (prel[:, 1:] == batch['labels'])[mask]  # correct relation labels
        union_true = arc_true & rel_true                   # correct head and label (LAS numerator)

        return {
            loss_tag: loss.item(),
            f"{metric_tag}/arc_true": torch.sum(arc_true, dtype=torch.float).item(),
            f"{metric_tag}/uni_true": torch.sum(union_true, dtype=torch.float).item(),
            f"{metric_tag}/all": union_true.numel()
        }
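The counts returned above are enough to recover attachment scores once an epoch is complete. A minimal aggregation sketch, assuming the per-batch dicts are collected into a list and that `metric_tag` holds the same prefix used in `step` (the `summarize` helper and the `"val"` default are illustrative, not part of the original code):

    def summarize(outputs, metric_tag="val"):
        """Turn the per-batch counts from step() into UAS / LAS ratios."""
        arc_true = sum(o[f"{metric_tag}/arc_true"] for o in outputs)
        uni_true = sum(o[f"{metric_tag}/uni_true"] for o in outputs)
        total = sum(o[f"{metric_tag}/all"] for o in outputs)
        return {
            "uas": arc_true / total,  # correct heads / scored tokens
            "las": uni_true / total,  # correct head-and-label / scored tokens
        }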
Example 2
    def step(self, batch, preds, targets, mask) -> dict:
        parc, prel = preds
        rarc, rrel = targets

        parc[:, 0, 1:] = float('-inf')
        parc.diagonal(0, 1, 2).fill_(float('-inf'))
        parc = eisner(
            parc,
            torch.cat([torch.zeros_like(mask[:, :1], dtype=torch.bool), mask],
                      dim=1))

        prel = torch.argmax(prel, dim=-1)
        prel = prel.gather(-1, parc.unsqueeze(-1)).squeeze(-1)

        # exclude punctuation from the score
        punc_mask = rrel != self.punct_idx
        mask = mask & punc_mask

        arc_true = (parc == rarc)[mask]
        rel_true = (prel == rrel)[mask]
        union_true = arc_true & rel_true

        return {
            "arc_true": torch.sum(arc_true, dtype=torch.float).item(),
            "uni_true": torch.sum(union_true, dtype=torch.float).item(),
            "all": union_true.numel()
        }
Example 3
    def sdp(self, hidden: dict, graph=True):
        """
        语义依存图(树)
        Args:
            hidden: 分词时所得到的中间表示
            graph: 选择是语义依存图还是语义依存树结果

        Returns:
            语义依存图(树)结果
        """
        word_attention_mask = hidden['word_cls_mask']
        sdp_arc, sdp_label = self.model.sdp_classifier(
            input=hidden['word_cls_input'],
            word_attention_mask=word_attention_mask[:, 1:])[0]
        sdp_arc[:, 0, 1:] = float('-inf')
        sdp_arc.diagonal(0, 1, 2)[1:].fill_(float('-inf'))  # prevent self-loops (a token heading itself)
        sdp_label = torch.argmax(sdp_label, dim=-1)

        if graph:
            # semantic dependency graph: keep every arc with sigmoid score > 0.5
            sdp_arc = torch.sigmoid_(sdp_arc) > 0.5
        else:
            # semantic dependency tree: one head per token, decoded with Eisner's algorithm
            sdp_arc_idx = eisner(
                sdp_arc, word_attention_mask).unsqueeze_(-1).expand_as(sdp_arc)
            sdp_arc = torch.zeros_like(sdp_arc, dtype=torch.bool).scatter_(
                -1, sdp_arc_idx, True)
        sdp_arc[~word_attention_mask] = False
        sdp_label = get_graph_entities(sdp_arc, sdp_label, self.sdp_vocab)

        return sdp_label
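The two in-place masking lines used before decoding are easy to verify on a toy tensor. A small sketch, assuming random scores for one sentence of three tokens plus the pseudo-root at index 0 (greedy `argmax` stands in for `eisner` here):

    import torch

    scores = torch.randn(1, 4, 4)                       # [batch, dependent, head] arc scores
    scores[:, 0, 1:] = float('-inf')                    # the pseudo-root may only attach to itself
    scores.diagonal(0, 1, 2)[1:].fill_(float('-inf'))   # no real token may be its own head
    heads = scores.argmax(dim=-1)                       # heads[0, 0] == 0 and heads[0, i] != i for i > 0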
Example 4
    def dep(self, hidden: dict, fast=True):
        """
        依存句法树
        Args:
            hidden: 分词时所得到的中间表示
            fast: 启用 fast 模式时,减少对结果的约束,速度更快,相应的精度会降低

        Returns:
            依存句法树结果
        """
        word_attention_mask = hidden['word_cls_mask']
        dep_arc, dep_label = self.model.dep_classifier.forward(
            input=hidden['word_cls_input'],
            word_attention_mask=word_attention_mask[:, 1:])[0]
        dep_arc[:, 0, 1:] = float('-inf')
        dep_arc.diagonal(0, 1, 2)[1:].fill_(float('-inf'))
        dep_arc = dep_arc.argmax(
            dim=-1) if fast else eisner(dep_arc, word_attention_mask)

        dep_label = torch.argmax(dep_label, dim=-1)
        dep_label = dep_label.gather(-1, dep_arc.unsqueeze(-1)).squeeze(-1)

        dep_arc[~word_attention_mask] = -1
        dep_label[~word_attention_mask] = -1

        arc_pred = [[item for item in arcs if item != -1]
                    for arcs in dep_arc[:, 1:].cpu().numpy().tolist()]
        rel_pred = [[self.dep_vocab[item] for item in rels if item != -1]
                    for rels in dep_label[:, 1:].cpu().numpy().tolist()]

        return [[(idx + 1, arc, rel)
                 for idx, (arc, rel) in enumerate(zip(arcs, rels))]
                for arcs, rels in zip(arc_pred, rel_pred)]
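The (idx, head, rel) triples returned above map directly onto CoNLL-style rows. A small formatting sketch, assuming a parallel list of surface words (the `words` argument and the `to_conll` helper are illustrative, not part of the original code):

    def to_conll(words, parse):
        """Render one sentence's (idx, head, rel) triples as tab-separated rows."""
        return "\n".join(
            f"{idx}\t{word}\t{head}\t{rel}"
            for word, (idx, head, rel) in zip(words, parse))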
Example 5
    def step(self: pl.LightningModule, batch, batch_nb):
        loss, (parc, prel) = self(**batch)

        mask: torch.Tensor = batch['word_attention_mask']

        parc[:, 0, 1:] = float('-inf')
        parc.diagonal(0, 1, 2)[1:].fill_(float('-inf'))
        parc = eisner(
            parc,
            torch.cat([torch.zeros_like(mask[:, :1], dtype=torch.bool), mask],
                      dim=1))

        prel = torch.argmax(prel, dim=-1)
        prel = prel.gather(-1, parc.unsqueeze(-1)).squeeze(-1)

        arc_true = (parc[:, 1:] == batch['head'])[mask]
        rel_true = (prel[:, 1:] == batch['labels'])[mask]
        union_true = arc_true & rel_true

        return {
            loss_tag: loss.item(),
            f"{metric_tag}/true": torch.sum(union_true,
                                            dtype=torch.float).item(),
            f"{metric_tag}/all": union_true.numel()
        }
Example 6
    def sdp(self, hidden: dict, mode: str = 'graph'):
        """
        语义依存图(树)
        Args:
            hidden: 分词时所得到的中间表示
            mode: ['tree', 'graph', 'mix']

        Returns:
            语义依存图(树)结果
        """
        if len(self.sdp_vocab) == 0:
            return []

        word_attention_mask = hidden['word_cls_mask']
        result = self.model.sdp_classifier(
            input=hidden['word_cls_input'],
            word_attention_mask=word_attention_mask[:, 1:],
            is_processed=True)
        sdp_arc, sdp_label = result.arc_logits, result.rel_logits
        sdp_arc[:, 0, 1:] = float('-inf')
        sdp_arc.diagonal(0, 1, 2).fill_(float('-inf'))  # prevent self-loops
        sdp_label = torch.argmax(sdp_label, dim=-1)

        if mode == 'tree':
            # semantic dependency tree
            sdp_arc_idx = eisner(
                sdp_arc, word_attention_mask).unsqueeze_(-1).expand_as(sdp_arc)
            sdp_arc_res = torch.zeros_like(sdp_arc, dtype=torch.bool).scatter_(
                -1, sdp_arc_idx, True)
        elif mode == 'mix':
            # mixed decoding: tree arcs merged into the thresholded graph
            sdp_arc_idx = eisner(
                sdp_arc, word_attention_mask).unsqueeze_(-1).expand_as(sdp_arc)
            sdp_arc_res = (sdp_arc.sigmoid_() > 0.5).scatter_(
                -1, sdp_arc_idx, True)
        else:
            # semantic dependency graph
            sdp_arc_res = torch.sigmoid_(sdp_arc) > 0.5

        sdp_arc_res[~word_attention_mask] = False
        sdp_label = get_graph_entities(sdp_arc_res, sdp_label, self.sdp_vocab)

        return sdp_label
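The `scatter_` call in the 'tree' and 'mix' branches converts a vector of head indices into a boolean adjacency matrix, which is the same shape the graph branch produces directly. A toy sketch of that conversion with hand-picked head indices (the tensors here are illustrative, not LTP output):

    import torch

    heads = torch.tensor([[0, 2, 0, 2]])             # head index of each position
    adj = torch.zeros(1, 4, 4, dtype=torch.bool)
    adj.scatter_(-1, heads.unsqueeze(-1), True)      # adj[b, i, heads[b, i]] = True

In the 'mix' branch the same scatter is applied to the thresholded sigmoid mask, so the tree arcs are always kept in addition to any graph arcs scoring above 0.5.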
Example 7
    def step(self, batch, preds, targets, mask) -> dict:
        seq_lens = torch.sum(mask, dim=-1, dtype=torch.int).cpu().numpy()

        parc, prel = preds
        rarc, rrel = targets

        mask = torch.cat(
            [torch.zeros_like(mask[:, :1], dtype=torch.bool), mask], dim=1)

        parc[:, 0, 1:] = float('-inf')
        parc.diagonal(0, 1, 2).fill_(float('-inf'))
        parc = eisner(parc, mask)

        # for every position i, the index of the position that follows it (i + 1)
        arange_index = torch.arange(1, mask.shape[1] + 1, dtype=torch.long, device=prel.device) \
            .unsqueeze(0) \
            .repeat(mask.shape[0], 1)
        # relation index 0 ("app", per the mask name) is only allowed where a token's
        # predicted head is the position immediately after it
        app_masks = parc.ne(arange_index)
        app_masks = app_masks.unsqueeze(-1).unsqueeze(-1).repeat(
            1, 1, 1, prel.shape[-1])
        app_masks[:, :, :, 1:] = 0

        prel = prel.masked_fill(app_masks, float("-inf"))
        prel = torch.argmax(prel, dim=-1)
        prel = prel.gather(-1, parc.unsqueeze(-1)).squeeze(-1)

        parc = parc[:, 1:]
        prel = prel[:, 1:]

        punc_mask = rrel == self.punct_idx
        rrel = rrel.masked_fill(punc_mask, -1)
        prel = prel.masked_fill(punc_mask, -1)

        rarc_entities, rrel_entities, rseg_entities = self.get_graph_entities(
            rarc, rrel, seq_lens)
        parc_entities, prel_entities, pseg_entities = self.get_graph_entities(
            parc, prel, seq_lens)

        return {
            'uas_correct': len(parc_entities & rarc_entities),
            'uas_pred': len(parc_entities),
            'uas_true': len(rarc_entities),
            'las_correct': len(prel_entities & rrel_entities),
            'las_pred': len(prel_entities),
            'las_true': len(rrel_entities),
            'seg_correct': len(pseg_entities & rseg_entities),
            'seg_pred': len(pseg_entities),
            'seg_true': len(rseg_entities),
        }
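The correct/pred/true counts returned here are the usual ingredients for precision, recall and F1. A minimal sketch of how they might be combined after accumulating the dicts over an epoch (the `f1` helper and the `totals` dict are assumptions, not part of the original code):

    def f1(correct, pred, true, eps=1e-12):
        """Precision / recall / F1 from accumulated correct, predicted and gold counts."""
        precision = correct / (pred + eps)
        recall = correct / (true + eps)
        return 2 * precision * recall / (precision + recall + eps)

    # e.g. labelled attachment F1:
    # las_f1 = f1(totals['las_correct'], totals['las_pred'], totals['las_true'])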
Example 8
    def dep(self, hidden: dict, fast=True, as_tuple=True):
        """
        依存句法树
        Args:
            hidden: 分词时所得到的中间表示
            fast: 启用 fast 模式时,减少对结果的约束,速度更快,相应的精度会降低
            as_tuple: 返回的结果是否为 (idx, head, rel) 的格式,否则返回 heads, rels

        Returns:
            依存句法树结果
        """
        if len(self.dep_vocab) == 0:
            return []
        word_attention_mask = hidden['word_cls_mask']
        result = self.model.dep_classifier.forward(
            input=hidden['word_cls_input'],
            word_attention_mask=word_attention_mask[:, 1:],
            is_processed=True)
        dep_arc, dep_label = result.arc_logits, result.rel_logits
        dep_arc[:, 0, 1:] = float('-inf')
        dep_arc.diagonal(0, 1, 2).fill_(float('-inf'))
        dep_arc = dep_arc.argmax(
            dim=-1) if fast else eisner(dep_arc, word_attention_mask)

        dep_label = torch.argmax(dep_label, dim=-1)
        dep_label = dep_label.gather(-1, dep_arc.unsqueeze(-1)).squeeze(-1)

        dep_arc[~word_attention_mask] = -1
        dep_label[~word_attention_mask] = -1

        head_pred = [[item for item in arcs if item != -1]
                     for arcs in dep_arc[:, 1:].cpu().numpy().tolist()]
        rel_pred = [[self.dep_vocab[item] for item in rels if item != -1]
                    for rels in dep_label[:, 1:].cpu().numpy().tolist()]
        if not as_tuple:
            return head_pred, rel_pred
        return [[(idx + 1, head, rel)
                 for idx, (head, rel) in enumerate(zip(heads, rels))]
                for heads, rels in zip(head_pred, rel_pred)]