Example #1
    def step(self, y_pred: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
             y: Tuple[torch.Tensor, torch.Tensor]):
        arc_pred, label_pred, seq_len = y_pred
        mask = length_to_mask(seq_len + 1)
        mask[:, 0] = False  # ignore the first token of each sentence
        if self._eisner:
            from ltp.utils import eisner
            arc_pred = eisner(arc_pred, mask)
        else:
            arc_pred = torch.argmax(arc_pred, dim=-1)
        label_pred = torch.argmax(label_pred, dim=-1)

        arc_real, label_real = y
        label_pred = label_pred.gather(-1, arc_pred.unsqueeze(-1)).squeeze(-1)

        mask = mask.narrow(-1, 1, mask.size(1) - 1)
        arc_pred = arc_pred.narrow(-1, 1, arc_pred.size(1) - 1)
        label_pred = label_pred.narrow(-1, 1, label_pred.size(1) - 1)

        head_true = (arc_pred == arc_real)[mask]
        label_true = (label_pred == label_real)[mask]

        self._head_true += torch.sum(head_true).item()
        self._label_true += torch.sum(label_true).item()
        self._union_true += torch.sum(label_true[head_true]).item()
        self._all += torch.sum(mask).item()
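Every example on this page calls a `length_to_mask` helper whose body is not shown. As a point of reference, a minimal sketch that matches how the examples use it (a boolean padding mask of shape [batch, max_len], with an optional dtype) could look like the following; this is an illustrative reconstruction, not the project's actual implementation:

    import torch

    def length_to_mask(length: torch.Tensor, max_len: int = None, dtype=None) -> torch.Tensor:
        # length: [batch] tensor of sequence lengths
        # returns a [batch, max_len] mask that is True (or 1 for a non-bool dtype) inside each sequence
        max_len = max_len or int(length.max().item())
        mask = torch.arange(max_len, device=length.device).expand(len(length), max_len) < length.unsqueeze(-1)
        if dtype is not None:
            mask = mask.to(dtype)
        return mask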
Example #2
File: ltp.py Project: xwsss1/ltp
    def srl(self, hidden: dict, keep_empty=True):
        # semantic role labeling
        word_length = torch.as_tensor(hidden['word_length'],
                                      device=hidden['word_input'].device)
        word_mask = length_to_mask(word_length)
        srl_output, srl_length, crf = self.model.srl_decoder(
            hidden['word_input'], hidden['word_length'])
        mask = word_mask.unsqueeze_(-1).expand(-1, -1, word_mask.size(1))
        mask = (mask & mask.transpose(-1, -2)).flatten(end_dim=1)
        index = mask[:, 0]
        mask = mask[index]

        srl_entities = crf.decode(srl_output.flatten(end_dim=1)[index], mask)
        srl_entities = self._get_entities_with_list(srl_entities,
                                                    self.srl_vocab)

        srl_labels_res = []
        for length in srl_length:
            srl_labels_res.append([])
            curr_srl_labels, srl_entities = srl_entities[:length], srl_entities[length:]
            srl_labels_res[-1].extend(curr_srl_labels)

        if not keep_empty:
            srl_labels_res = [[(idx, labels)
                               for idx, labels in enumerate(srl_labels)
                               if len(labels)]
                              for srl_labels in srl_labels_res]
        return srl_labels_res
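The mask manipulation in this example (the same idiom reappears in Examples #7, #9 and #12 below) turns a per-token mask into a token-pair mask and then keeps only the rows that belong to real, non-padded tokens. A self-contained toy illustration with made-up lengths:

    import torch

    seq_lens = torch.tensor([3, 2])                      # two sentences, padded to max_len 3
    max_len = int(seq_lens.max())
    word_mask = torch.arange(max_len).expand(2, max_len) < seq_lens.unsqueeze(-1)   # [B, L]

    # pair (i, j) is valid only if both token i and token j are inside the sentence
    mask = word_mask.unsqueeze(-1).expand(-1, -1, max_len)   # [B, L, L]
    mask = mask & mask.transpose(-1, -2)                     # symmetric token-pair mask
    mask = mask.flatten(end_dim=1)                           # [B*L, L]

    index = mask[:, 0]    # a row starts with True iff its own token is not padding
    mask = mask[index]    # drop padded rows: shape [3 + 2, L] here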
Example #3
    def forward(self, x: Tensor, length: Tensor, gold: Optional = None):
        """
        :param x: batch_size x max_len
        :param length: sequence length, B
        """

        mask = length_to_mask(length, dtype=torch.long)

        for layer in self.layers:
            x = layer(x, mask)
        return x, length, gold
Example #4
    def distill(self, inputs, targets, temperature_calc, distill_loss, gold=None):
        emissions, seq_lens, crf = inputs
        emissions_T, _, crf_T = targets

        mask = length_to_mask(seq_lens)
        mask = mask.unsqueeze_(-1).expand(-1, -1, mask.size(1))
        mask = mask & mask.transpose(-1, -2)

        logits_loss = F.mse_loss(emissions[mask], emissions_T[mask])
        crf_loss = F.mse_loss(crf.transitions, crf_T.transitions) + \
                   F.mse_loss(crf.start_transitions, crf_T.start_transitions) + \
                   F.mse_loss(crf.end_transitions, crf_T.end_transitions)

        return logits_loss + crf_loss
Example #5
    def seg(self, inputs: List[str]):
        length = torch.as_tensor([len(text) for text in inputs], device=self.device)
        tokenizerd = self.tokenizer.batch_encode_plus(inputs, return_tensors='pt')
        pretrained_output, *_ = self.model.pretrained(
            input_ids=tokenizerd['input_ids'].to(self.device),
            attention_mask=tokenizerd['attention_mask'].to(self.device),
            token_type_ids=tokenizerd['token_type_ids'].to(self.device)
        )

        # remove [CLS] [SEP]
        word_cls = pretrained_output[:, :1]
        char_input = torch.narrow(pretrained_output, 1, 1, pretrained_output.size(1) - 2)

        segment_output = torch.argmax(self.model.seg_decoder(char_input), dim=-1).cpu().numpy()
        segment_output = self._convert_idx_to_name(segment_output, length, self.seg_vocab)

        sentences = []
        word_idx = []
        word_length = []
        for source_text, encoding, sentence_seg_tag in zip(inputs, tokenizerd.encodings, segment_output):
            text = [source_text[start:end] for start, end in encoding.offsets[1:-1] if end != 0]

            last_word = 0
            for idx, word in enumerate(encoding.words[1:-1]):
                if word is None or is_chinese_char(text[idx][-1]):
                    continue
                if word != last_word:
                    text[idx] = ' ' + text[idx]
                    last_word = word
                else:
                    sentence_seg_tag[idx] = WORD_MIDDLE

            entities = get_entities(sentence_seg_tag)
            word_length.append(len(entities))

            sentences.append([''.join(text[entity[1]:entity[2] + 1]).strip() for entity in entities])
            word_idx.append(torch.as_tensor([entity[1] for entity in entities], device=self.device))

        word_idx = torch.nn.utils.rnn.pad_sequence(word_idx, batch_first=True)
        word_idx = word_idx.unsqueeze(-1).expand(-1, -1, char_input.shape[-1])

        word_input = torch.gather(char_input, dim=1, index=word_idx)

        word_cls_input = torch.cat([word_cls, word_input], dim=1)
        word_cls_mask = length_to_mask(torch.as_tensor(word_length, device=self.device) + 1)
        word_cls_mask[:, 0] = False  # ignore the first token of each sentence
        return sentences, {
            'word_cls': word_cls, 'word_input': word_input, 'word_length': word_length,
            'word_cls_input': word_cls_input, 'word_cls_mask': word_cls_mask
        }
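The last block of this example (also found in Examples #6, #16, #19 and #20) uses `torch.gather` with an expanded index to pull out the hidden vector of each word's first character. A standalone sketch of just that indexing step, with invented shapes:

    import torch

    batch, n_chars, dim = 2, 5, 4
    char_hidden = torch.randn(batch, n_chars, dim)        # char-level encoder output

    # start index of each word per sentence (ragged): 3 words vs. 2 words
    word_starts = [torch.tensor([0, 2, 3]), torch.tensor([0, 1])]
    word_idx = torch.nn.utils.rnn.pad_sequence(word_starts, batch_first=True)    # [B, max_words]
    word_idx = word_idx.unsqueeze(-1).expand(-1, -1, dim)                        # [B, max_words, dim]

    word_input = torch.gather(char_hidden, dim=1, index=word_idx)                # [B, max_words, dim]
    assert torch.equal(word_input[0, 1], char_hidden[0, 2])   # word 1 of sentence 0 starts at char 2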
Example #6
    def seg(self, inputs: List[str]):
        length = [len(text) for text in inputs]
        tokenizerd = self.tokenizer.batch_encode_plus(inputs, pad_to_max_length=True)
        pretrained_inputs = {key: convert(value) for key, value in tokenizerd.items()}
        cls, hidden, seg = self.onnx.run(None, pretrained_inputs)

        segment_output = self._convert_idx_to_name(seg, length, self.seg_vocab)

        word_cls = torch.as_tensor(cls, device=self.device)
        char_input = torch.as_tensor(hidden, device=self.device)

        sentences = []
        word_idx = []
        word_length = []
        for source_text, encoding, sentence_seg_tag in zip(inputs, tokenizerd.encodings, segment_output):
            text = [source_text[start:end] for start, end in encoding.offsets[1:-1] if end != 0]

            last_word = 0
            for idx, word in enumerate(encoding.words[1:-1]):
                if word is None or self._is_chinese_char(text[idx][-1]):
                    continue
                if word != last_word:
                    text[idx] = ' ' + text[idx]
                    last_word = word
                else:
                    sentence_seg_tag[idx] = WORD_MIDDLE

            entities = get_entities(sentence_seg_tag)
            word_length.append(len(entities))

            sentences.append([''.join(text[entity[1]:entity[2] + 1]).lstrip() for entity in entities])
            word_idx.append(torch.as_tensor([entity[1] for entity in entities], device=self.device))

        word_idx = torch.nn.utils.rnn.pad_sequence(word_idx, batch_first=True)
        word_idx = word_idx.unsqueeze(-1).expand(-1, -1, char_input.shape[-1])
        word_input = torch.gather(char_input, dim=1, index=word_idx)

        word_cls_input = torch.cat([word_cls, word_input], dim=1)
        word_cls_mask = length_to_mask(torch.as_tensor(word_length, device=self.device) + 1)
        word_cls_mask[:, 0] = False  # ignore the first token of each sentence

        return sentences, {
            'word_cls': word_cls, 'word_input': word_input, 'word_length': word_length,
            'word_cls_input': word_cls_input, 'word_cls_mask': word_cls_mask
        }
Example #7
    def predict(self, inputs, pred):
        srl_output, srl_length = pred
        mask = length_to_mask(srl_length)

        mask = mask.unsqueeze_(-1).expand(-1, -1, mask.size(1))
        mask = (mask & mask.transpose(-1, -2)).flatten(end_dim=1)
        index = mask[:, 0]
        mask = mask[index]

        srl_output = srl_output.flatten(end_dim=1)[index]
        srl_labels = torch.argmax(srl_output, dim=-1).cpu().numpy()
        srl_labels = self._convert_idx_to_name(srl_labels, mask.sum(dim=1))
        # srl_labels_res = []
        # for length in srl_length:
        #     srl_labels_res.append([])
        #     curr_srl_labels, srl_labels = srl_labels[:length], srl_labels[length:]
        #     srl_labels_res[-1].extend([get_entities(labels) for labels in curr_srl_labels])

        return srl_labels
Example #8
    def distill(self, inputs, targets, temperature_calc, distill_loss, gold=None):
        arc_scores, rel_scores, seq_lens = inputs
        arc_scores_T, rel_scores_T, _ = targets

        mask = length_to_mask(seq_lens + 1)
        mask[:, 0] = False  # ignore the first token of each sentence

        arc_logits = select_logits_with_mask(arc_scores, mask)
        arc_logits_T = select_logits_with_mask(arc_scores_T, mask)
        arc_temperature = temperature_calc(arc_logits, arc_logits_T)

        rel_logits = select_logits_with_mask(rel_scores, mask)
        rel_logits_T = select_logits_with_mask(rel_scores_T, mask)
        rel_temperature = temperature_calc(rel_logits, rel_logits_T)

        loss = 2 * ((1 - self.loss_interpolation) * self.kd_ce_loss(arc_logits, arc_logits_T, arc_temperature)
                    + self.loss_interpolation * self.kd_ce_loss(rel_logits, rel_logits_T, rel_temperature))

        return loss
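This example and Example #11 rely on a `select_logits_with_mask` helper that is not reproduced on this page. Judging only from the call sites, it presumably selects the score rows at unmasked token positions before the distillation loss is computed; a hypothetical one-liner with that behaviour is sketched below (the project's real helper may differ):

    import torch

    def select_logits_with_mask(logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # hypothetical sketch: logits [B, L, ...], mask [B, L] boolean;
        # boolean indexing keeps the trailing score dimensions and flattens batch x length
        return logits[mask]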
Example #9
    def forward(self, inputs, targets):
        emissions, seq_lens, crf = inputs
        rel_gold, rel_gold_set = targets

        mask = length_to_mask(seq_lens)
        mask = mask.unsqueeze_(-1).expand_as(rel_gold)
        mask = mask & mask.transpose(-1, -2)
        mask = mask.flatten(end_dim=1)
        index = mask[:, 0]

        mask = mask[index]
        emissions = emissions.flatten(end_dim=1)[index]
        rel_gold = rel_gold.flatten(end_dim=1)[index]

        if self.cross_entropy:
            cross_entropy = F.cross_entropy(emissions[mask], rel_gold[mask])
            crf_loss = crf.forward(emissions=emissions, tags=rel_gold, mask=mask, reduction=self.reduction)
            return cross_entropy - crf_loss
        else:
            return -crf.forward(emissions=emissions, tags=rel_gold, mask=mask, reduction=self.reduction)
Example #10
    def forward(self, inputs, targets):
        arcs, rels = targets
        arc_scores, rel_scores, seq_lens = inputs
        mask = length_to_mask(seq_lens + 1, dtype=torch.float)
        mask[:, 0] = 0  # ignore the first token of each sentence

        mask = mask.unsqueeze(-1)
        mask = mask.expand_as(arcs)

        arc_loss = F.binary_cross_entropy_with_logits(
            arc_scores, arcs, weight=mask, reduction=self.reduction
        )

        num_tags = rel_scores.shape[-1]
        rel_loss = F.cross_entropy(
            rel_scores.contiguous().view((-1, num_tags)), rels.contiguous().view(-1), weight=self.weight,
            ignore_index=self.ignore_index, reduction=self.reduction
        )
        loss = 2 * ((1 - self.loss_interpolation) * arc_loss + self.loss_interpolation * rel_loss)
        return loss
Example #11
    def distill(self, inputs, targets, temperature_calc, distill_loss, gold=None):
        arc_scores, rel_scores, seq_lens = inputs
        arc_scores_T, rel_scores_T, _ = targets

        mask = length_to_mask(seq_lens + 1)
        mask[:, 0] = False

        arc_mask = mask.unsqueeze(-1).expand_as(arc_scores)
        arc_logits = torch.sigmoid(arc_scores)[arc_mask]
        arc_logits_T = torch.sigmoid(arc_scores_T)[arc_mask]
        arc_temperature = temperature_calc(arc_logits, arc_logits_T)

        rel_logits = select_logits_with_mask(rel_scores, mask)
        rel_logits_T = select_logits_with_mask(rel_scores_T, mask)
        rel_temperature = temperature_calc(rel_logits, rel_logits_T)

        loss = 2 * ((1 - self.loss_interpolation) * F.mse_loss(arc_logits / arc_temperature,
                                                               arc_logits_T / arc_temperature)
                    + self.loss_interpolation * self.kd_ce_loss(rel_logits, rel_logits_T, rel_temperature))
        return loss
Example #12
    def step(self, y_pred: Tuple[torch.Tensor, torch.Tensor, Any],
             y: Tuple[torch.Tensor, set]):
        rel_gold, rels_gold_set = y
        rels_scores, seq_lens, crf = y_pred

        mask = length_to_mask(seq_lens)
        mask = mask.unsqueeze_(-1).expand(-1, -1, mask.size(1))
        mask = mask & mask.transpose(-1, -2)
        mask = mask.flatten(end_dim=1)
        index = mask[:, 0]

        rel_gold = rel_gold.flatten(end_dim=1)[index]

        mask = mask[index]
        pred_entities = crf.decode(rels_scores.flatten(end_dim=1)[index], mask)

        rel_entities = self.get_entities(rel_gold[mask])
        pred_entities = self.get_entities_with_list(pred_entities)

        self.nb_correct += len(rel_entities & pred_entities)
        self.nb_pred += len(pred_entities)
        self.nb_true += len(rel_entities)
Example #13
    def step(self, y_pred: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
             y: Tuple[torch.Tensor, torch.Tensor]):
        arc_pred, label_pred, seq_len = y_pred
        arc_real, label_real = y

        arc_real = arc_real > 0.5
        arc_pred = torch.sigmoid(arc_pred) > 0.5
        # to [B, L+1, L+1]
        label_pred = torch.argmax(label_pred, dim=-1)

        mask = length_to_mask(seq_len + 1)
        mask[:, 0] = False  # ignore the first token of each sentence
        mask = mask.unsqueeze(-1).expand_as(arc_pred)

        arc_pred[mask == False] = False

        true_entities = self.get_entities(arc_real, label_real)
        pred_entities = self.get_entities(arc_pred, label_pred)

        self.nb_correct += len(true_entities & pred_entities)
        self.nb_pred += len(pred_entities)
        self.nb_true += len(true_entities)
Example #14
    def forward(self, inputs, targets):
        arcs, rels = targets
        arc_scores, rel_scores, seq_lens = inputs
        mask = length_to_mask(seq_lens + 1)
        mask[:, 0] = False  # ignore the first token of each sentence

        arc_scores, rel_scores = arc_scores[mask], rel_scores[mask]

        # targets do not include the bos token
        mask = torch.narrow(mask, dim=-1, start=1, length=mask.size(1) - 1)
        arcs, rels = arcs[mask], rels[mask]
        rel_scores = rel_scores[torch.arange(len(arcs)), arcs]

        arc_loss = F.cross_entropy(
            arc_scores, arcs, weight=None,
            ignore_index=self.ignore_index, reduction=self.reduction
        )
        rel_loss = F.cross_entropy(
            rel_scores, rels, weight=None,
            ignore_index=self.ignore_index, reduction=self.reduction
        )
        loss = 2 * ((1 - self.loss_interpolation) * arc_loss + self.loss_interpolation * rel_loss)
        return loss
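The line `rel_scores[torch.arange(len(arcs)), arcs]` picks, for every token, the relation scores at its gold head position (the same idea appears as `label_pred.gather(...)` in Example #1). A tiny illustration of that advanced-indexing step, with invented shapes:

    import torch

    n_tokens, n_positions, n_rels = 5, 6, 4
    rel_scores = torch.randn(n_tokens, n_positions, n_rels)    # scores over (head position, relation)
    arcs = torch.tensor([0, 2, 2, 5, 1])                       # gold head index for each token

    picked = rel_scores[torch.arange(n_tokens), arcs]          # [n_tokens, n_rels]
    assert torch.equal(picked[3], rel_scores[3, 5])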
Example #15
    def step(self, y_pred, y: dict):
        mask = ~length_to_mask(y['text_length'])
        target = y['word_idn']
        target[mask] = -1
        super(Segment, self).step(y_pred, target)
Example #16
    def seg(self,
            inputs: Union[List[str], List[List[str]]],
            truncation: bool = True,
            is_preseged=False):
        """
        分词

        Args:
            inputs: 句子列表
            truncation: 是否对过长的句子进行截断,如果为 False 可能会抛出异常
            is_preseged:  是否已经进行过分词

        Returns:
            words: 分词后的序列
            hidden: 用于其他任务的中间表示
        """

        if transformers_version.major >= 3:  # newer tokenizers use is_split_into_words instead of is_pretokenized
            kwargs = {'is_split_into_words': is_preseged}
        else:
            kwargs = {'is_pretokenized': is_preseged}

        tokenized = self.tokenizer.batch_encode_plus(
            inputs,
            padding=True,
            truncation=truncation,
            return_tensors=self.tensor,
            max_length=self.max_length,
            **kwargs)
        cls, hidden, seg, lengths = self._seg(tokenized,
                                              is_preseged=is_preseged)

        batch_prefix = [[
            word_idx != encoding.words[idx - 1]
            for idx, word_idx in enumerate(encoding.words)
            if word_idx is not None
        ] for encoding in tokenized.encodings]

        # merge segments with maximum forward matching
        if self.trie.is_init and not is_preseged:
            matches = self.seg_with_dict(inputs, tokenized, batch_prefix)
            for sent_match, sent_seg in zip(matches, seg):
                for start, end in sent_match:
                    sent_seg[start] = self.seg_vocab_dict[WORD_START]
                    sent_seg[start + 1:end] = self.seg_vocab_dict[WORD_MIDDLE]
                    if end < len(sent_seg):
                        sent_seg[end] = self.seg_vocab_dict[WORD_START]

        if is_preseged:
            sentences = inputs
            word_length = [len(sentence) for sentence in sentences]

            word_idx = []
            for encodings in tokenized.encodings:
                sentence_word_idx = []
                for idx, (start, end) in enumerate(encodings.offsets[1:]):
                    if start == 0 and end != 0:
                        sentence_word_idx.append(idx)
                word_idx.append(
                    torch.as_tensor(sentence_word_idx, device=self.device))
        else:
            segment_output = convert_idx_to_name(seg, lengths, self.seg_vocab)
            sentences = []
            word_idx = []
            word_length = []

            for source_text, length, encoding, seg_tag, prefix in \
                    zip(inputs, lengths, tokenized.encodings, segment_output, batch_prefix):
                offsets = encoding.offsets[1:length + 1]
                text = []
                last_offset = None
                for start, end in offsets:
                    text.append('' if last_offset == (
                        start, end) else source_text[start:end])
                    last_offset = (start, end)

                for idx in range(1, length):
                    current_beg = offsets[idx][0]
                    forward_end = offsets[idx - 1][-1]
                    if forward_end < current_beg:
                        text[idx] = source_text[
                            forward_end:current_beg] + text[idx]
                    if not prefix[idx]:
                        seg_tag[idx] = WORD_MIDDLE

                entities = get_entities(seg_tag)
                word_length.append(len(entities))
                sentences.append([
                    ''.join(text[entity[1]:entity[2] + 1]).strip()
                    for entity in entities
                ])
                word_idx.append(
                    torch.as_tensor([entity[1] for entity in entities],
                                    device=self.device))

        word_idx = torch.nn.utils.rnn.pad_sequence(word_idx, batch_first=True)
        word_idx = word_idx.unsqueeze(-1).expand(-1, -1,
                                                 hidden.shape[-1])  # expand to hidden size

        word_input = torch.gather(hidden, dim=1,
                                  index=word_idx)  # vector of each word's first char

        if len(self.dep_vocab) + len(self.sdp_vocab) > 0:
            word_cls_input = torch.cat([cls, word_input], dim=1)
            word_cls_mask = length_to_mask(
                torch.as_tensor(word_length, device=self.device) + 1)
            word_cls_mask[:, 0] = False
        else:
            word_cls_input, word_cls_mask = None, None

        return sentences, {
            'word_cls': cls,
            'word_input': word_input,
            'word_length': word_length,
            'word_cls_input': word_cls_input,
            'word_cls_mask': word_cls_mask
        }
Example #17
    def distill(self, inputs, targets, temperature_calc, distill_loss, gold=None):
        mask = length_to_mask(gold['text_length'])
        logits = inputs[mask]
        logits_T = targets[mask]
        temperature = temperature_calc(logits, logits_T)
        return distill_loss(logits, logits_T, temperature)
Example #18
    def forward(self, inputs, targets):
        mask = length_to_mask(targets['text_length'])
        target = targets['word_idn']
        loss = F.cross_entropy(inputs[mask], target[mask], reduction=self.reduction)
        return loss
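Examples #15, #17 and #18 all index [B, L, C] logits and [B, L] targets with the same boolean padding mask, so the loss only sees real tokens. A minimal standalone version of that pattern, using invented shapes:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(2, 4, 3)                                   # [B, L, C]
    targets = torch.randint(0, 3, (2, 4))                           # [B, L]
    lengths = torch.tensor([4, 2])
    mask = torch.arange(4).expand(2, 4) < lengths.unsqueeze(-1)     # [B, L] padding mask

    loss = F.cross_entropy(logits[mask], targets[mask])             # [6, C] logits vs. [6] targets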
Example #19
    def seg(self, inputs: List[str]):
        tokenizerd = self.tokenizer.batch_encode_plus(
            inputs, return_tensors=self.tensor, padding=True)
        cls, hidden, seg, length = self._seg(tokenizerd)

        # merge segments with maximum forward matching
        if self.trie.is_init:
            matches = self.seg_with_dict(inputs, tokenizerd)
            for sent_match, sent_seg in zip(matches, seg):
                for start, end in sent_match:
                    sent_seg[start] = 0
                    sent_seg[start + 1:end] = 1
                    if end < len(sent_seg):
                        sent_seg[end] = 0

        segment_output = convert_idx_to_name(seg, length, self.seg_vocab)
        if USE_PLUGIN:
            offsets = [
                list(filter(lambda x: x != (0, 0), encodings.offsets))
                for encodings in tokenizerd.encodings
            ]
            words = [
                list(filter(lambda x: x is not None, encodings.words))
                for encodings in tokenizerd.encodings
            ]
            sentences, word_idx, word_length = segment_decode(
                inputs, segment_output, offsets, words)
            word_idx = [
                torch.as_tensor(idx, device=self.device) for idx in word_idx
            ]
        else:
            sentences = []
            word_idx = []
            word_length = []

            for source_text, encoding, sentence_seg_tag in zip(
                    inputs, tokenizerd.encodings, segment_output):
                text = [
                    source_text[start:end]
                    for start, end in encoding.offsets[1:-1] if end != 0
                ]

                last_word = 0
                for idx, word in enumerate(encoding.words[1:-1]):
                    if word is None or is_chinese_char(text[idx][-1]):
                        continue
                    if word != last_word:
                        text[idx] = ' ' + text[idx]
                        last_word = word
                    else:
                        sentence_seg_tag[idx] = WORD_MIDDLE

                entities = get_entities(sentence_seg_tag)
                word_length.append(len(entities))
                sentences.append([
                    ''.join(text[entity[1]:entity[2] + 1]).strip()
                    for entity in entities
                ])
                word_idx.append(
                    torch.as_tensor([entity[1] for entity in entities],
                                    device=self.device))

        word_idx = torch.nn.utils.rnn.pad_sequence(word_idx, batch_first=True)
        word_idx = word_idx.unsqueeze(-1).expand(-1, -1,
                                                 hidden.shape[-1])  # expand to hidden size

        word_input = torch.gather(hidden, dim=1,
                                  index=word_idx)  # vector of each word's first char

        word_cls_input = torch.cat([cls, word_input], dim=1)
        word_cls_mask = length_to_mask(
            torch.as_tensor(word_length, device=self.device) + 1)
        word_cls_mask[:, 0] = False  # ignore the first token of each sentence
        return sentences, {
            'word_cls': cls,
            'word_input': word_input,
            'word_length': word_length,
            'word_cls_input': word_cls_input,
            'word_cls_mask': word_cls_mask
        }
Example #20
File: ltp.py Project: wgmzone/ltp
    def seg(self, inputs: Union[List[str], List[List[str]]], truncation: bool = True, is_preseged=False):
        """
        分词

        Args:
            inputs: 句子列表
            truncation: 是否对过长的句子进行截断,如果为 False 可能会抛出异常
            is_preseged:  是否已经进行过分词

        Returns:
            words: 分词后的序列
            hidden: 用于其他任务的中间表示
        """
        tokenized = self.tokenizer.batch_encode_plus(
            inputs, padding=True, truncation=truncation,
            return_tensors=self.tensor, max_length=self.max_length,
            is_pretokenized=is_preseged
        )
        cls, hidden, seg, lengths = self._seg(tokenized, is_preseged=is_preseged)

        # merge segments with maximum forward matching
        if self.trie.is_init and not is_preseged:
            matches = self.seg_with_dict(inputs, tokenized)
            for sent_match, sent_seg in zip(matches, seg):
                for start, end in sent_match:
                    sent_seg[start] = 0
                    sent_seg[start + 1:end] = 1
                    if end < len(sent_seg):
                        sent_seg[end] = 0

        if is_preseged:
            sentences = inputs
            word_length = [len(sentence) for sentence in sentences]

            word_idx = []
            for encodings in tokenized.encodings:
                sentence_word_idx = []
                for idx, (start, end) in enumerate(encodings.offsets[1:]):
                    if start == 0 and end != 0:
                        sentence_word_idx.append(idx)
                word_idx.append(torch.as_tensor(sentence_word_idx, device=self.device))
        else:
            segment_output = convert_idx_to_name(seg, lengths, self.seg_vocab)
            sentences = []
            word_idx = []
            word_length = []

            for source_text, length, encoding, seg_tag in zip(inputs, lengths, tokenized.encodings, segment_output):
                words = encoding.words[1:length + 1]
                offsets = encoding.offsets[1:length + 1]
                text = [source_text[start:end] for start, end in offsets]

                for idx in range(1, length):
                    current_beg = offsets[idx][0]
                    forward_end = offsets[idx - 1][-1]
                    if forward_end < current_beg:
                        text[idx] = source_text[forward_end:current_beg] + text[idx]
                    if words[idx - 1] == words[idx]:
                        seg_tag[idx] = WORD_MIDDLE

                entities = get_entities(seg_tag)
                word_length.append(len(entities))
                sentences.append([''.join(text[entity[1]:entity[2] + 1]).strip() for entity in entities])
                word_idx.append(torch.as_tensor([entity[1] for entity in entities], device=self.device))

        word_idx = torch.nn.utils.rnn.pad_sequence(word_idx, batch_first=True)
        word_idx = word_idx.unsqueeze(-1).expand(-1, -1, hidden.shape[-1])  # expand to hidden size

        word_input = torch.gather(hidden, dim=1, index=word_idx)  # vector of each word's first char

        word_cls_input = torch.cat([cls, word_input], dim=1)
        word_cls_mask = length_to_mask(torch.as_tensor(word_length, device=self.device) + 1)
        word_cls_mask[:, 0] = False  # ignore the first token of each sentence
        return sentences, {
            'word_cls': cls, 'word_input': word_input, 'word_length': word_length,
            'word_cls_input': word_cls_input, 'word_cls_mask': word_cls_mask
        }