Example 1
    def sdp(self, hidden: dict, graph=True):
        # Semantic dependency parsing
        sdp_arc, sdp_label, _ = self.model.sdp_decoder(hidden['word_cls_input'], hidden['word_length'])
        sdp_arc = torch.sigmoid_(sdp_arc)

        if graph:
            # Semantic dependency graph
            sdp_arc.transpose_(-1, -2)
            # Force exactly one root arc per sentence: zero the root row,
            # then re-enable only the highest-scoring root attachment.
            sdp_root_mask = sdp_arc[:, 0].argmax(dim=-1).unsqueeze_(-1).expand_as(sdp_arc[:, 0])
            sdp_arc[:, 0] = 0
            sdp_arc[:, 0].scatter_(dim=-1, index=sdp_root_mask, value=1)
            sdp_arc_T = sdp_arc.transpose(-1, -2)
            sdp_arc_fix = sdp_arc_T.argmax(dim=-1).unsqueeze_(-1).expand_as(sdp_arc)
            # Keep arcs that clear the 0.5 threshold and beat the reverse
            # direction; guarantee each word at least its best-scoring head.
            sdp_arc = ((sdp_arc_T > 0.5) & (sdp_arc_T > sdp_arc)) \
                .scatter_(dim=-1, index=sdp_arc_fix, value=True)
        else:
            # Semantic dependency tree
            sdp_arc_fix = eisner(sdp_arc, hidden['word_cls_mask']).unsqueeze_(-1).expand_as(sdp_arc)
            sdp_arc = torch.zeros_like(sdp_arc, dtype=torch.bool).scatter_(dim=-1, index=sdp_arc_fix, value=True)

        sdp_label = torch.argmax(sdp_label, dim=-1)

        # Mask out arc positions that fall in the padding region.
        word_cls_mask = hidden['word_cls_mask']
        word_cls_mask = word_cls_mask.unsqueeze(-1).expand(-1, -1, word_cls_mask.size(1))
        sdp_arc = sdp_arc & word_cls_mask
        sdp_label = self._get_graph_entities(sdp_arc, sdp_label, self.sdp_vocab)

        return sdp_label
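
The helper `_get_graph_entities` is not shown here. A minimal sketch of the conversion it appears to perform, turning the boolean arc tensor and the label ids into per-sentence triples (the (batch, dependent, head) orientation and the output format are assumptions):

import torch

def get_graph_entities_sketch(arcs, labels, vocab):
    # arcs:   bool tensor (batch, n, n); True marks a predicted arc
    # labels: long tensor (batch, n, n); best label id for every arc slot
    # vocab:  sequence mapping label ids to relation names
    triples = [[] for _ in range(arcs.size(0))]
    for sent, dep, head in torch.nonzero(arcs, as_tuple=False).tolist():
        label_id = labels[sent, dep, head].item()
        triples[sent].append((dep, head, vocab[label_id]))
    return triples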
Example 2
File: ltp.py Project: xwsss1/ltp
    def dep(self, hidden: dict, fast=False):
        # Dependency parse tree
        dep_arc, dep_label, word_length = self.model.dep_decoder(
            hidden['word_cls_input'], hidden['word_length'])
        if fast:
            # Fast mode: pick each word's best head independently; the
            # result is not guaranteed to form a tree.
            dep_arc_fix = dep_arc.argmax(
                dim=-1).unsqueeze_(-1).expand_as(dep_arc)
        else:
            # Eisner decoding: highest-scoring projective dependency tree.
            dep_arc_fix = eisner(
                dep_arc,
                hidden['word_cls_mask']).unsqueeze_(-1).expand_as(dep_arc)
        dep_arc = torch.zeros_like(dep_arc, dtype=torch.bool).scatter_(
            dim=-1, index=dep_arc_fix, value=True)

        # Mask label scores from index `dep_fix` onward so argmax ignores them.
        dep_label[:, :, :, self.dep_fix:] = float('-inf')
        dep_label = torch.argmax(dep_label, dim=-1)

        word_cls_mask = hidden['word_cls_mask']
        word_cls_mask = word_cls_mask.unsqueeze(-1).expand(
            -1, -1, word_cls_mask.size(1))
        dep_arc = dep_arc & word_cls_mask
        dep_label = self._get_graph_entities(dep_arc, dep_label,
                                             self.dep_vocab)

        return dep_label
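
The `fast` branch replaces Eisner decoding with an independent argmax per word, which is cheaper but does not guarantee a well-formed tree. A self-contained toy illustration of the difference (invented scores, not LTP output):

import torch

# Head-score matrix for one 3-word sentence plus a virtual root at index 0:
# scores[dep, head] -- higher means `head` more likely governs `dep`.
scores = torch.tensor([[0.0, 0.0, 0.0, 0.0],   # root row, ignored
                       [0.9, 0.0, 0.8, 0.1],   # word 1 prefers the root
                       [0.1, 0.7, 0.0, 0.9],   # word 2 prefers word 3
                       [0.2, 0.1, 0.8, 0.0]])  # word 3 prefers word 2

heads = scores.argmax(dim=-1)
print(heads)  # tensor([0, 0, 3, 2]): words 2 and 3 head each other,
              # a cycle that the Eisner (non-fast) path cannot produce.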
Example 3
    def dep(self, hidden: dict, fast=False):
        """
        依存句法树
        Args:
            hidden: 分词时所得到的中间表示
            fast: 启用 fast 模式时,减少对结果的约束,速度更快,相应的精度会降低

        Returns:
            依存句法树结果
        """
        dep_arc, dep_label, word_length = self.model.dep_decoder(
            hidden['word_cls_input'], hidden['word_length'])
        if fast:
            dep_arc_fix = dep_arc.argmax(
                dim=-1).unsqueeze_(-1).expand_as(dep_arc)
        else:
            dep_arc_fix = eisner(
                dep_arc,
                hidden['word_cls_mask']).unsqueeze_(-1).expand_as(dep_arc)
        dep_arc = torch.zeros_like(dep_arc, dtype=torch.bool).scatter_(
            -1, dep_arc_fix, True)
        dep_label = torch.argmax(dep_label, dim=-1)

        word_cls_mask = hidden['word_cls_mask']
        word_cls_mask = word_cls_mask.unsqueeze(-1).expand(
            -1, -1, word_cls_mask.size(1))
        dep_arc = dep_arc & word_cls_mask
        dep_label = get_graph_entities(dep_arc, dep_label, self.dep_vocab)

        return dep_label
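
Taken together, the snippets suggest a call pattern along these lines (the pipeline construction and exact return format are assumptions based on the LTP 4 API):

from ltp import LTP

ltp = LTP()  # load a pretrained pipeline
words, hidden = ltp.seg(["他叫汤姆去拿外衣。"])
tree = ltp.dep(hidden)                  # Eisner-decoded projective tree
tree_fast = ltp.dep(hidden, fast=True)  # per-word argmax, faster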
Example 4
    def step(self, y_pred: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
             y: Tuple[torch.Tensor, torch.Tensor]):
        arc_pred, label_pred, seq_len = y_pred
        mask = length_to_mask(seq_len + 1)
        mask[:, 0] = False  # index 0 is the virtual root; never a dependent
        if self._eisner:
            from ltp.utils import eisner
            arc_pred = eisner(arc_pred, mask)
        else:
            arc_pred = torch.argmax(arc_pred, dim=-1)
        label_pred = torch.argmax(label_pred, dim=-1)

        arc_real, label_real = y
        label_pred = label_pred.gather(-1, arc_pred.unsqueeze(-1)).squeeze(-1)

        # Drop the virtual-root position so metrics count real words only.
        mask = mask.narrow(-1, 1, mask.size(1) - 1)
        arc_pred = arc_pred.narrow(-1, 1, arc_pred.size(1) - 1)
        label_pred = label_pred.narrow(-1, 1, label_pred.size(1) - 1)

        head_true = (arc_pred == arc_real)[mask]
        label_true = (label_pred == label_real)[mask]

        # Accumulate counts for UAS (correct head), label accuracy, and
        # LAS (correct head and correct label).
        self._head_true += torch.sum(head_true).item()
        self._label_true += torch.sum(label_true).item()
        self._union_true += torch.sum(label_true[head_true]).item()
        self._all += torch.sum(mask).item()
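
`length_to_mask` is imported from elsewhere in the project; here is a minimal sketch consistent with how `step` uses it (the real helper may differ in detail):

import torch

def length_to_mask(lengths, max_len=None):
    # mask[i, j] is True iff position j lies inside sequence i.
    max_len = max_len if max_len is not None else int(lengths.max())
    positions = torch.arange(max_len, device=lengths.device)
    return positions.unsqueeze(0) < lengths.unsqueeze(-1)

# step() masks seq_len + 1 positions to cover the virtual root at index 0,
# then switches the root off so it is never scored as a dependent.
mask = length_to_mask(torch.tensor([2, 3]) + 1)
mask[:, 0] = False
print(mask)
# tensor([[False,  True,  True, False],
#         [False,  True,  True,  True]])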
Example 5
File: ltp.py Project: pyw123/ltp
    def sdp(self, hidden: dict, graph=True):
        """
        语义依存图(树)
        Args:
            hidden: 分词时所得到的中间表示
            graph: 选择是语义依存图还是语义依存树结果

        Returns:
            语义依存图(树)结果
        """
        sdp_arc, sdp_label, _ = self.model.sdp_decoder(
            hidden['word_cls_input'], hidden['word_length'])
        sdp_arc = torch.sigmoid_(sdp_arc)

        # Avoid self-loops: a word must not be its own head
        eye = torch.arange(0, sdp_arc.size(1), device=sdp_arc.device).view(
            1, 1, -1).expand(sdp_arc.size(0), -1, -1)
        sdp_arc.scatter_(1, eye, 0)

        if graph:
            # Semantic dependency graph
            sdp_arc.transpose_(-1, -2)
            # Force exactly one root arc per sentence.
            sdp_root_mask = sdp_arc[:, 0].argmax(
                dim=-1).unsqueeze_(-1).expand_as(sdp_arc[:, 0])
            sdp_arc[:, 0] = 0
            sdp_arc[:, 0].scatter_(dim=-1, index=sdp_root_mask, value=1)
            sdp_arc_T = sdp_arc.transpose(-1, -2)
            sdp_arc_fix = sdp_arc_T.argmax(
                dim=-1).unsqueeze_(-1).expand_as(sdp_arc)
            sdp_arc = ((sdp_arc_T > 0.5) & (sdp_arc_T > sdp_arc)).scatter_(
                -1, sdp_arc_fix, True)
        else:
            # Semantic dependency tree
            sdp_arc_fix = eisner(
                sdp_arc,
                hidden['word_cls_mask']).unsqueeze_(-1).expand_as(sdp_arc)
            sdp_arc = torch.zeros_like(sdp_arc, dtype=torch.bool).scatter_(
                -1, sdp_arc_fix, True)

        sdp_label = torch.argmax(sdp_label, dim=-1)

        word_cls_mask = hidden['word_cls_mask']
        word_cls_mask = word_cls_mask.unsqueeze(-1).expand(
            -1, -1, word_cls_mask.size(1))
        # Require both endpoints to be real words, then re-allow the
        # virtual-root column so root attachments survive the mask.
        word_cls_mask = word_cls_mask & word_cls_mask.transpose(-1, -2)
        for si, length in enumerate(hidden['word_length']):
            word_cls_mask[si, :length, 0] = True
        sdp_arc = sdp_arc & word_cls_mask
        sdp_label = get_graph_entities(sdp_arc, sdp_label, self.sdp_vocab)

        return sdp_label
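
The self-loop scatter in this example is equivalent to zeroing the diagonal of each score matrix; a standalone demonstration of the trick:

import torch

scores = torch.rand(2, 4, 4)
# Scatter 0 along dim 1 at row index j for every column j, i.e.
# scores[b, j, j] = 0, so no word can be chosen as its own head.
eye = torch.arange(scores.size(1)).view(1, 1, -1).expand(scores.size(0), -1, -1)
scores.scatter_(1, eye, 0)
assert torch.all(scores.diagonal(dim1=-2, dim2=-1) == 0)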