示例#1
0
    def __init__(self, path: str = 'small', device=None, **kwargs):
        if device is not None:
            if isinstance(device, torch.device):
                self.device = device
            elif isinstance(device, str):
                self.device = torch.device(device)
        elif torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')

        if path in model_map or is_remote_url(path) or os.path.isfile(path):
            proxies = kwargs.pop("proxies", None)
            cache_dir = kwargs.pop("cache_dir", LTP_CACHE)
            force_download = kwargs.pop("force_download", False)
            resume_download = kwargs.pop("resume_download", False)
            local_files_only = kwargs.pop("local_files_only", False)
            path = cached_path(model_map.get(path, path),
                               cache_dir=cache_dir,
                               force_download=force_download,
                               proxies=proxies,
                               resume_download=resume_download,
                               local_files_only=local_files_only,
                               extract_compressed_file=True)
        elif not os.path.isdir(path):
            raise FileNotFoundError()
        try:
            ckpt = torch.load(os.path.join(path, "ltp.model"),
                              map_location=self.device)
        except Exception as e:
            fake_import_pytorch_lightning()
            ckpt = torch.load(os.path.join(path, "ltp.model"),
                              map_location=self.device)

        self.cache_dir = path
        config = AutoConfig.for_model(**ckpt['transformer_config'])
        self.model = Model(ckpt['model_config'], config=config).to(self.device)
        self.model.load_state_dict(ckpt['model'], strict=False)
        self.model.eval()
        self.max_length = self.model.transformer.config.max_position_embeddings
        self.seg_vocab = ckpt.get('seg', [WORD_MIDDLE, WORD_START])
        self.pos_vocab = ckpt.get('pos', [])
        self.ner_vocab = ckpt.get('ner', [])
        self.dep_vocab = ckpt.get('dep', [])
        self.sdp_vocab = ckpt.get('sdp', [])
        self.srl_vocab = [
            re.sub(r'ARG(\d)', r'A\1', tag.lstrip('ARGM-'))
            for tag in ckpt.get('srl', [])
        ]
        self.tokenizer = AutoTokenizer.from_pretrained(
            path, config=self.model.transformer.config, use_fast=True)
        self.trie = Trie()
示例#2
0
文件: frontend.py 项目: ztzdxqj/ltp
class LTP(object):
    model: Model
    seg_vocab: List[str]
    pos_vocab: List[str]
    ner_vocab: List[str]
    dep_vocab: List[str]
    sdp_vocab: List[str]
    srl_vocab: List[str]

    tensor: TensorType = TensorType.PYTORCH

    def __init__(self, path: str = 'small', device=None, **kwargs):
        if device is not None:
            if isinstance(device, torch.device):
                self.device = device
            elif isinstance(device, str):
                self.device = torch.device(device)
        elif torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')

        if path in model_map or is_remote_url(path) or os.path.isfile(path):
            proxies = kwargs.pop("proxies", None)
            cache_dir = kwargs.pop("cache_dir", LTP_CACHE)
            force_download = kwargs.pop("force_download", False)
            resume_download = kwargs.pop("resume_download", False)
            local_files_only = kwargs.pop("local_files_only", False)
            path = cached_path(model_map.get(path, path),
                               cache_dir=cache_dir,
                               force_download=force_download,
                               proxies=proxies,
                               resume_download=resume_download,
                               local_files_only=local_files_only,
                               extract_compressed_file=True)
        elif not os.path.isdir(path):
            raise FileNotFoundError()
        try:
            ckpt = torch.load(os.path.join(path, "ltp.model"),
                              map_location=self.device)
        except Exception as e:
            fake_import_pytorch_lightning()
            ckpt = torch.load(os.path.join(path, "ltp.model"),
                              map_location=self.device)

        patch_4_1_3(ckpt)

        self.cache_dir = path
        transformer_config = ckpt['transformer_config']
        transformer_config['torchscript'] = True
        config = AutoConfig.for_model(**transformer_config)
        self.model = Model(ckpt['model_config'], config=config).to(self.device)
        self.model.load_state_dict(ckpt['model'], strict=False)
        self.model.eval()

        self.seg_vocab = ckpt.get('seg', [WORD_MIDDLE, WORD_START])
        self.seg_vocab_dict = {
            tag: idx
            for idx, tag in enumerate(self.seg_vocab)
        }
        self.pos_vocab = ckpt.get('pos', [])
        self.ner_vocab = ckpt.get('ner', [])
        self.dep_vocab = ckpt.get('dep', [])
        self.sdp_vocab = ckpt.get('sdp', [])
        self.srl_vocab = [
            re.sub(r'ARG(\d)', r'A\1', tag.lstrip('ARGM-'))
            for tag in ckpt.get('srl', [])
        ]
        self.tokenizer = AutoTokenizer.from_pretrained(
            path, config=self.model.transformer.config, use_fast=True)
        self.trie = Trie()
        self._model_version = ckpt.get('version', None)

    def __str__(self):
        return f"LTP {self.version} on {self.device} (model version: {self.model_version}) "

    def __repr__(self):
        return f"LTP {self.version} on {self.device} (model version: {self.model_version}) "

    @property
    def avaliable_models(self):
        return model_map.keys()

    @property
    def version(self):
        from ltp import __version__ as version
        return version

    @property
    def model_version(self):
        return self._model_version or 'unknown'

    @property
    def max_length(self):
        return self.model.transformer.config.max_position_embeddings

    def init_dict(self, path, max_window=None):
        self.trie.init(path, max_window)

    def add_words(self, words, max_window=None):
        self.trie.add_words(words)
        self.trie.max_window = max_window

    @staticmethod
    def sent_split(inputs: List[str], flag: str = "all", limit: int = 510):
        inputs = [
            split_sentence(text, flag=flag, limit=limit) for text in inputs
        ]
        inputs = list(itertools.chain(*inputs))
        return inputs

    def seg_with_dict(self, inputs: List[str], tokenized: BatchEncoding,
                      batch_prefix):
        # 进行正向字典匹配
        matching = []
        for source_text, encoding, preffix in zip(inputs, tokenized.encodings,
                                                  batch_prefix):
            text = [
                source_text[start:end] for start, end in encoding.offsets[1:-1]
                if end != 0
            ]
            matching_pos = self.trie.maximum_forward_matching(text, preffix)
            matching.append(matching_pos)
        return matching

    @no_gard
    def _seg(self, tokenizerd, is_preseged=False):
        input_ids = tokenizerd['input_ids'].to(self.device)
        attention_mask = tokenizerd['attention_mask'].to(self.device)
        token_type_ids = tokenizerd['token_type_ids'].to(self.device)
        length = torch.sum(attention_mask, dim=-1) - 2

        pretrained_output, *_ = self.model.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=False)

        # remove [CLS] [SEP]
        word_cls = pretrained_output[:, :1]
        char_input = torch.narrow(pretrained_output, 1, 1,
                                  pretrained_output.size(1) - 2)
        if is_preseged:
            segment_output = None
        else:
            segment_output = torch.argmax(
                self.model.seg_classifier(char_input).logits,
                dim=-1).cpu().numpy()
        return word_cls, char_input, segment_output, length

    @no_gard
    def seg(self,
            inputs: Union[List[str], List[List[str]]],
            truncation: bool = True,
            is_preseged=False):
        """
        分词

        Args:
            inputs: 句子列表
            truncation: 是否对过长的句子进行截断,如果为 False 可能会抛出异常
            is_preseged:  是否已经进行过分词

        Returns:
            words: 分词后的序列
            hidden: 用于其他任务的中间表示
        """

        if transformers_version.major >= 3 and transformers_version.major > 1:
            kwargs = {'is_split_into_words': is_preseged}
        else:
            kwargs = {'is_pretokenized': is_preseged}

        tokenized = self.tokenizer.batch_encode_plus(
            inputs,
            padding=True,
            truncation=truncation,
            return_tensors=self.tensor,
            max_length=self.max_length,
            **kwargs)
        cls, hidden, seg, lengths = self._seg(tokenized,
                                              is_preseged=is_preseged)

        batch_prefix = [[
            word_idx != encoding.words[idx - 1]
            for idx, word_idx in enumerate(encoding.words)
            if word_idx is not None
        ] for encoding in tokenized.encodings]

        # merge segments with maximum forward matching
        if self.trie.is_init and not is_preseged:
            matches = self.seg_with_dict(inputs, tokenized, batch_prefix)
            for sent_match, sent_seg in zip(matches, seg):
                for start, end in sent_match:
                    sent_seg[start] = self.seg_vocab_dict[WORD_START]
                    sent_seg[start + 1:end] = self.seg_vocab_dict[WORD_MIDDLE]
                    if end < len(sent_seg):
                        sent_seg[end] = self.seg_vocab_dict[WORD_START]

        if is_preseged:
            sentences = inputs
            word_length = [len(sentence) for sentence in sentences]

            word_idx = []
            for encodings in tokenized.encodings:
                sentence_word_idx = []
                for idx, (start, end) in enumerate(encodings.offsets[1:]):
                    if start == 0 and end != 0:
                        sentence_word_idx.append(idx)
                word_idx.append(
                    torch.as_tensor(sentence_word_idx, device=self.device))
        else:
            segment_output = convert_idx_to_name(seg, lengths, self.seg_vocab)
            sentences = []
            word_idx = []
            word_length = []

            for source_text, length, encoding, seg_tag, preffix in \
                    zip(inputs, lengths, tokenized.encodings, segment_output, batch_prefix):
                offsets = encoding.offsets[1:length + 1]
                text = []
                last_offset = None
                for start, end in offsets:
                    text.append('' if last_offset == (
                        start, end) else source_text[start:end])
                    last_offset = (start, end)

                for idx in range(1, length):
                    current_beg = offsets[idx][0]
                    forward_end = offsets[idx - 1][-1]
                    if forward_end < current_beg:
                        text[idx] = source_text[
                            forward_end:current_beg] + text[idx]
                    if not preffix[idx]:
                        seg_tag[idx] = WORD_MIDDLE

                entities = get_entities(seg_tag)
                word_length.append(len(entities))
                sentences.append([
                    ''.join(text[entity[1]:entity[2] + 1]).strip()
                    for entity in entities
                ])
                word_idx.append(
                    torch.as_tensor([entity[1] for entity in entities],
                                    device=self.device))

        word_idx = torch.nn.utils.rnn.pad_sequence(word_idx, batch_first=True)
        word_idx = word_idx.unsqueeze(-1).expand(-1, -1,
                                                 hidden.shape[-1])  # 展开

        word_input = torch.gather(hidden, dim=1,
                                  index=word_idx)  # 每个word第一个char的向量

        if len(self.dep_vocab) + len(self.sdp_vocab) > 0:
            word_cls_input = torch.cat([cls, word_input], dim=1)
            word_cls_mask = length_to_mask(
                torch.as_tensor(word_length, device=self.device) + 1)
            word_cls_mask[:, 0] = False
        else:
            word_cls_input, word_cls_mask = None, None

        return sentences, {
            'word_cls': cls,
            'word_input': word_input,
            'word_length': word_length,
            'word_cls_input': word_cls_input,
            'word_cls_mask': word_cls_mask
        }

    @no_gard
    def pos(self, hidden: dict):
        """
        词性标注
        Args:
            hidden: 分词时所得到的中间表示

        Returns:
            pos: 词性标注结果
        """
        if len(self.pos_vocab) == 0:
            return []
        postagger_output = self.model.pos_classifier(
            hidden['word_input']).logits
        postagger_output = torch.argmax(postagger_output, dim=-1).cpu().numpy()
        postagger_output = convert_idx_to_name(postagger_output,
                                               hidden['word_length'],
                                               self.pos_vocab)
        return postagger_output

    @no_gard
    def ner(self, hidden: dict, as_entities=True):
        """
        命名实体识别
        Args:
            hidden: 分词时所得到的中间表示
            as_entities: 是否以 Entity(Type, Start, End) 的形式返回

        Returns:
            pos: 命名实体识别结果
        """
        if len(self.ner_vocab) == 0:
            return []
        ner_output = self.model.ner_classifier.forward(
            hidden['word_input'],
            word_attention_mask=hidden['word_cls_mask'][:, 1:])
        ner_output = ner_output.decoded or torch.argmax(ner_output.logits,
                                                        dim=-1).cpu().numpy()
        ner_output = convert_idx_to_name(ner_output, hidden['word_length'],
                                         self.ner_vocab)
        return [get_entities(ner)
                for ner in ner_output] if as_entities else ner_output

    @no_gard
    def srl(self, hidden: dict, keep_empty=True):
        """
        语义角色标注
        Args:
            hidden: 分词时所得到的中间表示

        Returns:
            pos: 语义角色标注结果
        """
        if len(self.srl_vocab) == 0:
            return []
        srl_output = self.model.srl_classifier.forward(
            input=hidden['word_input'],
            word_attention_mask=hidden['word_cls_mask'][:, 1:]).decoded
        srl_entities = get_entities_with_list(srl_output, self.srl_vocab)

        srl_labels_res = []
        for length in hidden['word_length']:
            srl_labels_res.append([])
            curr_srl_labels, srl_entities = srl_entities[:
                                                         length], srl_entities[
                                                             length:]
            srl_labels_res[-1].extend(curr_srl_labels)

        if not keep_empty:
            srl_labels_res = [[(idx, labels)
                               for idx, labels in enumerate(srl_labels)
                               if len(labels)]
                              for srl_labels in srl_labels_res]
        return srl_labels_res

    @no_gard
    def dep(self, hidden: dict, fast=True, as_tuple=True):
        """
        依存句法树
        Args:
            hidden: 分词时所得到的中间表示
            fast: 启用 fast 模式时,减少对结果的约束,速度更快,相应的精度会降低
            as_tuple: 返回的结果是否为 (idx, head, rel) 的格式,否则返回 heads, rels

        Returns:
            依存句法树结果
        """
        if len(self.dep_vocab) == 0:
            return []
        word_attention_mask = hidden['word_cls_mask']
        result = self.model.dep_classifier.forward(
            input=hidden['word_cls_input'],
            word_attention_mask=word_attention_mask[:, 1:])
        dep_arc, dep_label = result.arc_logits, result.rel_logits
        dep_arc[:, 0, 1:] = float('-inf')
        dep_arc.diagonal(0, 1, 2).fill_(float('-inf'))
        dep_arc = dep_arc.argmax(
            dim=-1) if fast else eisner(dep_arc, word_attention_mask)

        dep_label = torch.argmax(dep_label, dim=-1)
        dep_label = dep_label.gather(-1, dep_arc.unsqueeze(-1)).squeeze(-1)

        dep_arc[~word_attention_mask] = -1
        dep_label[~word_attention_mask] = -1

        head_pred = [[item for item in arcs if item != -1]
                     for arcs in dep_arc[:, 1:].cpu().numpy().tolist()]
        rel_pred = [[self.dep_vocab[item] for item in rels if item != -1]
                    for rels in dep_label[:, 1:].cpu().numpy().tolist()]
        if not as_tuple:
            return head_pred, rel_pred
        return [[(idx + 1, head, rel)
                 for idx, (head, rel) in enumerate(zip(heads, rels))]
                for heads, rels in zip(head_pred, rel_pred)]

    @no_gard
    def sdp(self, hidden: dict, mode: str = 'graph'):
        """
        语义依存图(树)
        Args:
            hidden: 分词时所得到的中间表示
            mode: ['tree', 'graph', 'mix']

        Returns:
            语义依存图(树)结果
        """
        if len(self.sdp_vocab) == 0:
            return []

        word_attention_mask = hidden['word_cls_mask']
        result = self.model.sdp_classifier(
            input=hidden['word_cls_input'],
            word_attention_mask=word_attention_mask[:, 1:])
        sdp_arc, sdp_label = result.arc_logits, result.rel_logits
        sdp_arc[:, 0, 1:] = float('-inf')
        sdp_arc.diagonal(0, 1, 2).fill_(float('-inf'))  # 避免自指
        sdp_label = torch.argmax(sdp_label, dim=-1)

        if mode == 'tree':
            # 语义依存树
            sdp_arc_idx = eisner(
                sdp_arc, word_attention_mask).unsqueeze_(-1).expand_as(sdp_arc)
            sdp_arc_res = torch.zeros_like(sdp_arc, dtype=torch.bool).scatter_(
                -1, sdp_arc_idx, True)
        elif mode == 'mix':
            # 混合解码
            sdp_arc_idx = eisner(
                sdp_arc, word_attention_mask).unsqueeze_(-1).expand_as(sdp_arc)
            sdp_arc_res = (sdp_arc.sigmoid_() > 0.5).scatter_(
                -1, sdp_arc_idx, True)
        else:
            # 语义依存图
            sdp_arc_res = torch.sigmoid_(sdp_arc) > 0.5

        sdp_arc_res[~word_attention_mask] = False
        sdp_label = get_graph_entities(sdp_arc_res, sdp_label, self.sdp_vocab)

        return sdp_label