Example #1
    def step(self, y_pred: torch.Tensor, y: torch.Tensor):
        y_pred = torch.argmax(y_pred, dim=-1)
        y_pred, y = self.predict(y_pred, y)

        # decode both label sequences into (type, start, end) chunks
        true_entities = set(get_entities(y, self.suffix))
        pred_entities = set(get_entities(y_pred, self.suffix))

        # accumulate counts for chunk-level precision/recall/F1
        self.nb_correct += len(true_entities & pred_entities)
        self.nb_pred += len(pred_entities)
        self.nb_true += len(true_entities)
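A minimal sketch of how the counts accumulated in step would typically be reduced to scores; the compute name and the zero-division guards are assumptions, but the formulas are the standard chunk-level precision, recall, and F1:

    def compute(self):
        # precision: correct chunks / predicted chunks; recall: correct / gold
        precision = self.nb_correct / self.nb_pred if self.nb_pred else 0.0
        recall = self.nb_correct / self.nb_true if self.nb_true else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        return {'precision': precision, 'recall': recall, 'f1': f1}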
Example #2
    def ner(self, hidden: dict):
        # named entity recognition
        word_length = torch.as_tensor(hidden['word_length'], device=self.device)
        ner_output = self.model.ner_decoder(hidden['word_input'], word_length)
        ner_output = torch.argmax(ner_output, dim=-1).cpu().numpy()
        ner_output = self._convert_idx_to_name(ner_output, hidden['word_length'], self.ner_vocab)
        return [get_entities(ner) for ner in ner_output]
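The hidden dict consumed here matches what the seg methods in Examples #4 and #6 below return; a hedged usage sketch (the model variable and call order are assumptions drawn from those examples):

sentences, hidden = model.seg(['他叫汤姆去拿外衣。'])
entities = model.ner(hidden)  # one list of (type, start, end) chunks per sentence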
Example #3
    def _get_entities_with_list(self, labels_, itos):
        res = []
        for labels in labels_:
            labels = [itos[label] for label in labels]
            labels = get_entities(labels)
            res.append(labels)
        return res
Example #4
    def seg(self, inputs: List[str]):
        length = torch.as_tensor([len(text) for text in inputs], device=self.device)
        tokenizerd = self.tokenizer.batch_encode_plus(inputs, padding=True, return_tensors='pt')  # pad so variable-length texts can be batched into tensors
        pretrained_output, *_ = self.model.pretrained(
            input_ids=tokenizerd['input_ids'].to(self.device),
            attention_mask=tokenizerd['attention_mask'].to(self.device),
            token_type_ids=tokenizerd['token_type_ids'].to(self.device)
        )

        # remove [CLS] [SEP]
        word_cls = pretrained_output[:, :1]
        char_input = torch.narrow(pretrained_output, 1, 1, pretrained_output.size(1) - 2)

        segment_output = torch.argmax(self.model.seg_decoder(char_input), dim=-1).cpu().numpy()
        segment_output = self._convert_idx_to_name(segment_output, length, self.seg_vocab)

        sentences = []
        word_idx = []
        word_length = []
        for source_text, encoding, sentence_seg_tag in zip(inputs, tokenizerd.encodings, segment_output):
            text = [source_text[start:end] for start, end in encoding.offsets[1:-1] if end != 0]

            # for non-Chinese tokens, restore a space at each new word boundary
            # and re-tag word-internal sub-tokens as word-middle
            last_word = 0
            for idx, word in enumerate(encoding.words[1:-1]):
                if word is None or is_chinese_char(text[idx][-1]):
                    continue
                if word != last_word:
                    text[idx] = ' ' + text[idx]
                    last_word = word
                else:
                    sentence_seg_tag[idx] = WORD_MIDDLE

            entities = get_entities(sentence_seg_tag)
            word_length.append(len(entities))

            sentences.append([''.join(text[entity[1]:entity[2] + 1]).strip() for entity in entities])
            word_idx.append(torch.as_tensor([entity[1] for entity in entities], device=self.device))

        word_idx = torch.nn.utils.rnn.pad_sequence(word_idx, batch_first=True)
        word_idx = word_idx.unsqueeze(-1).expand(-1, -1, char_input.shape[-1])

        word_input = torch.gather(char_input, dim=1, index=word_idx)

        word_cls_input = torch.cat([word_cls, word_input], dim=1)
        word_cls_mask = length_to_mask(torch.as_tensor(word_length, device=self.device) + 1)
        word_cls_mask[:, 0] = False  # ignore the first token of each sentence
        return sentences, {
            'word_cls': word_cls, 'word_input': word_input, 'word_length': word_length,
            'word_cls_input': word_cls_input, 'word_cls_mask': word_cls_mask
        }
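Both seg variants (this one and Example #6 below) depend on a length_to_mask helper; a minimal sketch of what such a utility typically computes, not necessarily the library's exact implementation:

import torch

def length_to_mask(lengths: torch.Tensor, max_len: int = None) -> torch.Tensor:
    # boolean mask that is True at positions strictly before each sequence's length
    max_len = max_len or int(lengths.max())
    return torch.arange(max_len, device=lengths.device)[None, :] < lengths[:, None]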
Example #5
    def predict(self, inputs, pred):
        input_text = getattr(inputs, self.text_field)
        target_len = input_text['text_length'].cpu().detach().numpy().tolist()

        pred = torch.argmax(pred, dim=-1).cpu().detach().numpy()
        text = input_text['input_ids'].cpu().detach().numpy()
        pred = self.convert_idx_to_name(pred, target_len, id2label=self.id2label)
        res = []
        for token_ids, pred_single, length in zip(text, pred, target_len):
            # skip [CLS] and keep exactly `length` sub-tokens
            chars = [self.id2text.convert_ids_to_tokens(int(char)) for char in token_ids[1:1 + length]]
            entities = get_entities(pred_single)
            sentence = ["".join(chars[entity[1]:entity[2] + 1]) for entity in entities]
            res.append(sentence)
        return res
Example #6
    def seg(self, inputs: List[str]):
        length = [len(text) for text in inputs]
        tokenizerd = self.tokenizer.batch_encode_plus(inputs, pad_to_max_length=True)
        pretrained_inputs = {key: convert(value) for key, value in tokenizerd.items()}
        cls, hidden, seg = self.onnx.run(None, pretrained_inputs)

        segment_output = self._convert_idx_to_name(seg, length, self.seg_vocab)

        word_cls = torch.as_tensor(cls, device=self.device)
        char_input = torch.as_tensor(hidden, device=self.device)

        sentences = []
        word_idx = []
        word_length = []
        for source_text, encoding, sentence_seg_tag in zip(inputs, tokenizerd.encodings, segment_output):
            text = [source_text[start:end] for start, end in encoding.offsets[1:-1] if end != 0]

            last_word = 0
            for idx, word in enumerate(encoding.words[1:-1]):
                if word is None or self._is_chinese_char(text[idx][-1]):
                    continue
                if word != last_word:
                    text[idx] = ' ' + text[idx]
                    last_word = word
                else:
                    sentence_seg_tag[idx] = WORD_MIDDLE

            entities = get_entities(sentence_seg_tag)
            word_length.append(len(entities))

            sentences.append([''.join(text[entity[1]:entity[2] + 1]).lstrip() for entity in entities])
            word_idx.append(torch.as_tensor([entity[1] for entity in entities], device=self.device))

        word_idx = torch.nn.utils.rnn.pad_sequence(word_idx, batch_first=True)
        word_idx = word_idx.unsqueeze(-1).expand(-1, -1, char_input.shape[-1])
        word_input = torch.gather(char_input, dim=1, index=word_idx)

        word_cls_input = torch.cat([word_cls, word_input], dim=1)
        word_cls_mask = length_to_mask(torch.as_tensor(word_length, device=self.device) + 1)
        word_cls_mask[:, 0] = False  # ignore the first token of each sentence

        return sentences, {
            'word_cls': word_cls, 'word_input': word_input, 'word_length': word_length,
            'word_cls_input': word_cls_input, 'word_cls_mask': word_cls_mask
        }
Example #7
def test_span():
    span = ['B-PER', 'I-PER', 'O', 'B-LOC']
    assert get_entities(span) == [('PER', 0, 1), ('LOC', 3, 3)]
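get_entities also accepts a suffix flag for schemes where the position marker trails the type, which is what Example #1 passes through as self.suffix; a hedged extension of the test above, assuming seqeval's suffix tag format (e.g. 'PER-B'):

def test_span_suffix():
    span = ['PER-B', 'PER-I', 'O', 'LOC-B']
    assert get_entities(span, suffix=True) == [('PER', 0, 1), ('LOC', 3, 3)]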
Example #8
    def get_entities_with_list(self, labels):
        labels = [self.field.vocab.itos[label] for label in chain(*labels)]
        labels = get_entities(labels)
        return set(labels)
Example #9
    def get_entities(self, labels):
        labels = labels.cpu().detach().numpy()
        labels = [self.field.vocab.itos[label] for label in labels]
        # calls the module-level get_entities, not this method recursively
        labels = get_entities(labels)

        return set(labels)
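Common to every example above: get_entities returns (chunk_type, start, end) triples with an inclusive end index, which is why callers slice with entity[1]:entity[2] + 1. A self-contained sketch, assuming the seqeval import path:

from seqeval.metrics.sequence_labeling import get_entities

tags = ['O', 'B-LOC', 'I-LOC', 'O', 'B-PER']
for chunk_type, start, end in get_entities(tags):
    print(chunk_type, tags[start:end + 1])  # end index is inclusive
# prints: LOC ['B-LOC', 'I-LOC'] then PER ['B-PER']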