def step(self, y_pred: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
         y: Tuple[torch.Tensor, torch.Tensor]):
    arc_pred, label_pred, seq_len = y_pred
    mask = length_to_mask(seq_len + 1)
    mask[:, 0] = False  # ignore the pseudo root of each sentence
    if self._eisner:
        from ltp.utils import eisner
        arc_pred = eisner(arc_pred, mask)
    else:
        arc_pred = torch.argmax(arc_pred, dim=-1)
    label_pred = torch.argmax(label_pred, dim=-1)
    arc_real, label_real = y

    # pick, for each token, the label predicted for its predicted head
    label_pred = label_pred.gather(-1, arc_pred.unsqueeze(-1)).squeeze(-1)

    # drop the root column so predictions align with the gold tensors
    mask = mask.narrow(-1, 1, mask.size(1) - 1)
    arc_pred = arc_pred.narrow(-1, 1, arc_pred.size(1) - 1)
    label_pred = label_pred.narrow(-1, 1, label_pred.size(1) - 1)

    head_true = (arc_pred == arc_real)[mask]
    label_true = (label_pred == label_real)[mask]

    self._head_true += torch.sum(head_true).item()
    self._label_true += torch.sum(label_true).item()
    self._union_true += torch.sum(label_true[head_true]).item()
    self._all += torch.sum(mask).item()
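# All of the methods in this file build their padding masks with
# `length_to_mask`. A minimal sketch of that helper, assuming it returns a
# boolean [B, max_len] mask that is True at valid (non-padding) positions;
# the real implementation in ltp.nn may differ in signature and defaults:
import torch

def length_to_mask(length: torch.Tensor, max_len: int = None, dtype=None):
    max_len = max_len if max_len is not None else int(length.max().item())
    mask = torch.arange(max_len, device=length.device)[None, :] < length[:, None]
    return mask if dtype is None else mask.to(dtype)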
def srl(self, hidden: dict, keep_empty=True):  # semantic role labeling
    word_length = torch.as_tensor(hidden['word_length'],
                                  device=hidden['word_input'].device)
    word_mask = length_to_mask(word_length)
    srl_output, srl_length, crf = self.model.srl_decoder(
        hidden['word_input'], hidden['word_length'])

    # token-pair mask: (i, j) is valid iff both tokens are real
    mask = word_mask.unsqueeze_(-1).expand(-1, -1, word_mask.size(1))
    mask = (mask & mask.transpose(-1, -2)).flatten(end_dim=1)
    index = mask[:, 0]  # keep only rows whose candidate predicate is real
    mask = mask[index]

    srl_entities = crf.decode(srl_output.flatten(end_dim=1)[index], mask)
    srl_entities = self._get_entities_with_list(srl_entities, self.srl_vocab)

    srl_labels_res = []
    for length in srl_length:
        srl_labels_res.append([])
        curr_srl_labels, srl_entities = srl_entities[:length], srl_entities[length:]
        srl_labels_res[-1].extend(curr_srl_labels)

    if not keep_empty:
        srl_labels_res = [
            [(idx, labels) for idx, labels in enumerate(srl_labels) if len(labels)]
            for srl_labels in srl_labels_res
        ]
    return srl_labels_res
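# The unsqueeze/transpose dance above builds a token-pair mask: pair (i, j)
# is valid iff both token i and token j are inside the sequence. A toy check
# of that pattern, standalone, reusing the length_to_mask sketch above:
lengths = torch.tensor([3, 2])
m = length_to_mask(lengths)                       # [2, 3], True for real tokens
pair = m.unsqueeze(-1) & m.unsqueeze(-2)          # [2, 3, 3]
assert pair[0].sum() == 9 and pair[1].sum() == 4  # 3x3 and 2x2 valid blocks
# After flatten(end_dim=1), row b*L + i describes predicate i of sentence b,
# and pair[:, 0] is True exactly for rows whose predicate token is real.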
def forward(self, x: Tensor, length: Tensor, gold: Optional = None):
    """
    :param x: batch_size x max_len
    :param length: sequence lengths, shape [B]
    """
    mask = length_to_mask(length, dtype=torch.long)
    for layer in self.layers:
        x = layer(x, mask)
    return x, length, gold
def distill(self, inputs, targets, temperature_calc, distill_loss, gold=None):
    emissions, seq_lens, crf = inputs
    emissions_T, _, crf_T = targets

    mask = length_to_mask(seq_lens)
    mask = mask.unsqueeze_(-1).expand(-1, -1, mask.size(1))
    mask = mask & mask.transpose(-1, -2)

    # match student emissions to the teacher's on all valid token pairs
    logits_loss = F.mse_loss(emissions[mask], emissions_T[mask])
    # also match the CRF transition parameters directly
    crf_loss = F.mse_loss(crf.transitions, crf_T.transitions) + \
               F.mse_loss(crf.start_transitions, crf_T.start_transitions) + \
               F.mse_loss(crf.end_transitions, crf_T.end_transitions)
    return logits_loss + crf_loss
def seg(self, inputs: List[str]):
    length = torch.as_tensor([len(text) for text in inputs], device=self.device)
    # padding is required so ragged batches can be stacked into tensors
    tokenizerd = self.tokenizer.batch_encode_plus(inputs, padding=True, return_tensors='pt')
    pretrained_output, *_ = self.model.pretrained(
        input_ids=tokenizerd['input_ids'].to(self.device),
        attention_mask=tokenizerd['attention_mask'].to(self.device),
        token_type_ids=tokenizerd['token_type_ids'].to(self.device)
    )

    # remove [CLS] [SEP]
    word_cls = pretrained_output[:, :1]
    char_input = torch.narrow(pretrained_output, 1, 1, pretrained_output.size(1) - 2)

    segment_output = torch.argmax(self.model.seg_decoder(char_input), dim=-1).cpu().numpy()
    segment_output = self._convert_idx_to_name(segment_output, length, self.seg_vocab)

    sentences = []
    word_idx = []
    word_length = []

    for source_text, encoding, sentence_seg_tag in zip(inputs, tokenizerd.encodings, segment_output):
        text = [source_text[start:end] for start, end in encoding.offsets[1:-1] if end != 0]

        last_word = 0
        for idx, word in enumerate(encoding.words[1:-1]):
            if word is None or is_chinese_char(text[idx][-1]):
                continue
            if word != last_word:
                text[idx] = ' ' + text[idx]
                last_word = word
            else:
                sentence_seg_tag[idx] = WORD_MIDDLE

        entities = get_entities(sentence_seg_tag)
        word_length.append(len(entities))
        sentences.append([''.join(text[entity[1]:entity[2] + 1]).strip() for entity in entities])
        word_idx.append(torch.as_tensor([entity[1] for entity in entities], device=self.device))

    word_idx = torch.nn.utils.rnn.pad_sequence(word_idx, batch_first=True)
    word_idx = word_idx.unsqueeze(-1).expand(-1, -1, char_input.shape[-1])
    word_input = torch.gather(char_input, dim=1, index=word_idx)

    word_cls_input = torch.cat([word_cls, word_input], dim=1)
    word_cls_mask = length_to_mask(torch.as_tensor(word_length, device=self.device) + 1)
    word_cls_mask[:, 0] = False  # ignore the first token of each sentence

    return sentences, {
        'word_cls': word_cls,
        'word_input': word_input,
        'word_length': word_length,
        'word_cls_input': word_cls_input,
        'word_cls_mask': word_cls_mask
    }
def seg(self, inputs: List[str]):
    length = [len(text) for text in inputs]
    tokenizerd = self.tokenizer.batch_encode_plus(inputs, pad_to_max_length=True)
    pretrained_inputs = {key: convert(value) for key, value in tokenizerd.items()}
    # the ONNX session returns numpy arrays: [CLS] vectors, char hidden
    # states, and segment tag ids
    cls, hidden, seg = self.onnx.run(None, pretrained_inputs)

    segment_output = self._convert_idx_to_name(seg, length, self.seg_vocab)
    word_cls = torch.as_tensor(cls, device=self.device)
    char_input = torch.as_tensor(hidden, device=self.device)

    sentences = []
    word_idx = []
    word_length = []

    for source_text, encoding, sentence_seg_tag in zip(inputs, tokenizerd.encodings, segment_output):
        text = [source_text[start:end] for start, end in encoding.offsets[1:-1] if end != 0]

        last_word = 0
        for idx, word in enumerate(encoding.words[1:-1]):
            if word is None or self._is_chinese_char(text[idx][-1]):
                continue
            if word != last_word:
                text[idx] = ' ' + text[idx]
                last_word = word
            else:
                sentence_seg_tag[idx] = WORD_MIDDLE

        entities = get_entities(sentence_seg_tag)
        word_length.append(len(entities))
        sentences.append([''.join(text[entity[1]:entity[2] + 1]).lstrip() for entity in entities])
        word_idx.append(torch.as_tensor([entity[1] for entity in entities], device=self.device))

    word_idx = torch.nn.utils.rnn.pad_sequence(word_idx, batch_first=True)
    word_idx = word_idx.unsqueeze(-1).expand(-1, -1, char_input.shape[-1])
    word_input = torch.gather(char_input, dim=1, index=word_idx)

    word_cls_input = torch.cat([word_cls, word_input], dim=1)
    word_cls_mask = length_to_mask(torch.as_tensor(word_length, device=self.device) + 1)
    word_cls_mask[:, 0] = False  # ignore the first token of each sentence

    return sentences, {
        'word_cls': word_cls,
        'word_input': word_input,
        'word_length': word_length,
        'word_cls_input': word_cls_input,
        'word_cls_mask': word_cls_mask
    }
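# Both seg variants return the same (sentences, hidden) structure. A
# hypothetical usage sketch, assuming an instantiated pipeline named `ltp`;
# the example sentence follows the LTP README:
#
#   words, hidden = ltp.seg(['他叫汤姆去拿外衣。'])
#   # words  -> [['他', '叫', '汤姆', '去', '拿', '外衣', '。']]
#   # hidden -> dict consumed by the pos/ner/srl/dep/sdp decoders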
def predict(self, inputs, pred):
    srl_output, srl_length = pred
    mask = length_to_mask(srl_length)
    mask = mask.unsqueeze_(-1).expand(-1, -1, mask.size(1))
    mask = (mask & mask.transpose(-1, -2)).flatten(end_dim=1)
    index = mask[:, 0]
    mask = mask[index]

    srl_output = srl_output.flatten(end_dim=1)[index]
    srl_labels = torch.argmax(srl_output, dim=-1).cpu().numpy()
    srl_labels = self._convert_idx_to_name(srl_labels, mask.sum(dim=1))
    # srl_labels_res = []
    # for length in srl_length:
    #     srl_labels_res.append([])
    #     curr_srl_labels, srl_labels = srl_labels[:length], srl_labels[length:]
    #     srl_labels_res[-1].extend([get_entities(labels) for labels in curr_srl_labels])
    return srl_labels
def distill(self, inputs, targets, temperature_calc, distill_loss, gold=None):
    arc_scores, rel_scores, seq_lens = inputs
    arc_scores_T, rel_scores_T, _ = targets

    mask = length_to_mask(seq_lens + 1)
    mask[:, 0] = False  # ignore the first token of each sentence

    arc_logits = select_logits_with_mask(arc_scores, mask)
    arc_logits_T = select_logits_with_mask(arc_scores_T, mask)
    arc_temperature = temperature_calc(arc_logits, arc_logits_T)

    rel_logits = select_logits_with_mask(rel_scores, mask)
    rel_logits_T = select_logits_with_mask(rel_scores_T, mask)
    rel_temperature = temperature_calc(rel_logits, rel_logits_T)

    loss = 2 * ((1 - self.loss_interpolation) *
                self.kd_ce_loss(arc_logits, arc_logits_T, arc_temperature) +
                self.loss_interpolation *
                self.kd_ce_loss(rel_logits, rel_logits_T, rel_temperature))
    return loss
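# `self.kd_ce_loss` is the usual temperature-scaled soft cross entropy used
# in knowledge distillation. A minimal sketch, assuming the conventional
# formulation; the project may use a library implementation (e.g.
# TextBrewer's) that differs in scaling details:
import torch.nn.functional as F

def kd_ce_loss(logits_S, logits_T, temperature=1.0):
    # soften teacher and student distributions, then take cross entropy
    p_T = F.softmax(logits_T / temperature, dim=-1)
    log_p_S = F.log_softmax(logits_S / temperature, dim=-1)
    return -(p_T * log_p_S).sum(dim=-1).mean()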
def forward(self, inputs, targets):
    emissions, seq_lens, crf = inputs
    rel_gold, rel_gold_set = targets

    mask = length_to_mask(seq_lens)
    mask = mask.unsqueeze_(-1).expand_as(rel_gold)
    mask = mask & mask.transpose(-1, -2)
    mask = mask.flatten(end_dim=1)
    index = mask[:, 0]
    mask = mask[index]

    emissions = emissions.flatten(end_dim=1)[index]
    rel_gold = rel_gold.flatten(end_dim=1)[index]

    if self.cross_entropy:
        cross_entropy = F.cross_entropy(emissions[mask], rel_gold[mask])
        crf_loss = crf.forward(emissions=emissions, tags=rel_gold, mask=mask,
                               reduction=self.reduction)
        return cross_entropy - crf_loss
    else:
        return -crf.forward(emissions=emissions, tags=rel_gold, mask=mask,
                            reduction=self.reduction)
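# Note: crf.forward returns the log-likelihood of the gold tags, so the
# `cross_entropy - crf_loss` branch is token-level cross entropy plus the
# CRF negative log-likelihood, and the else branch is the plain CRF NLL.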
def forward(self, inputs, targets):
    arcs, rels = targets
    arc_scores, rel_scores, seq_lens = inputs

    mask = length_to_mask(seq_lens + 1, dtype=torch.float)
    mask[:, 0] = 0  # ignore the first token of each sentence
    mask = mask.unsqueeze(-1)
    mask = mask.expand_as(arcs)

    arc_loss = F.binary_cross_entropy_with_logits(
        arc_scores, arcs, weight=mask, reduction=self.reduction
    )
    num_tags = rel_scores.shape[-1]
    rel_loss = F.cross_entropy(
        rel_scores.contiguous().view((-1, num_tags)),
        rels.contiguous().view(-1),
        weight=self.weight,
        ignore_index=self.ignore_index,
        reduction=self.reduction
    )
    loss = 2 * ((1 - self.loss_interpolation) * arc_loss +
                self.loss_interpolation * rel_loss)
    return loss
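# With interpolation weight λ = self.loss_interpolation, the combined loss is
# 2 * ((1 - λ) * L_arc + λ * L_rel); the leading factor of 2 makes λ = 0.5
# reduce to the plain sum L_arc + L_rel. The distillation losses above and
# below reuse the same scheme.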
def distill(self, inputs, targets, temperature_calc, distill_loss, gold=None):
    arc_scores, rel_scores, seq_lens = inputs
    arc_scores_T, rel_scores_T, _ = targets

    mask = length_to_mask(seq_lens + 1)
    mask[:, 0] = False  # ignore the first token of each sentence

    arc_mask = mask.unsqueeze(-1).expand_as(arc_scores)
    arc_logits = torch.sigmoid(arc_scores)[arc_mask]
    arc_logits_T = torch.sigmoid(arc_scores_T)[arc_mask]
    arc_temperature = temperature_calc(arc_logits, arc_logits_T)

    rel_logits = select_logits_with_mask(rel_scores, mask)
    rel_logits_T = select_logits_with_mask(rel_scores_T, mask)
    rel_temperature = temperature_calc(rel_logits, rel_logits_T)

    loss = 2 * ((1 - self.loss_interpolation) *
                F.mse_loss(arc_logits / arc_temperature,
                           arc_logits_T / arc_temperature) +
                self.loss_interpolation *
                self.kd_ce_loss(rel_logits, rel_logits_T, rel_temperature))
    return loss
def step(self, y_pred: Tuple[torch.Tensor, torch.Tensor, Any],
         y: Tuple[torch.Tensor, set]):
    rel_gold, rels_gold_set = y
    rels_scores, seq_lens, crf = y_pred

    mask = length_to_mask(seq_lens)
    mask = mask.unsqueeze_(-1).expand(-1, -1, mask.size(1))
    mask = mask & mask.transpose(-1, -2)
    mask = mask.flatten(end_dim=1)
    index = mask[:, 0]
    rel_gold = rel_gold.flatten(end_dim=1)[index]
    mask = mask[index]

    pred_entities = crf.decode(rels_scores.flatten(end_dim=1)[index], mask)
    rel_entities = self.get_entities(rel_gold[mask])
    pred_entities = self.get_entities_with_list(pred_entities)

    self.nb_correct += len(rel_entities & pred_entities)
    self.nb_pred += len(pred_entities)
    self.nb_true += len(rel_entities)
def step(self, y_pred: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
         y: Tuple[torch.Tensor, torch.Tensor]):
    arc_pred, label_pred, seq_len = y_pred
    arc_real, label_real = y

    arc_real = arc_real > 0.5
    arc_pred = torch.sigmoid(arc_pred) > 0.5  # to [B, L+1, L+1]
    label_pred = torch.argmax(label_pred, dim=-1)

    mask = length_to_mask(seq_len + 1)
    mask[:, 0] = False  # ignore the first token of each sentence
    mask = mask.unsqueeze(-1).expand_as(arc_pred)
    arc_pred[~mask] = False

    true_entities = self.get_entities(arc_real, label_real)
    pred_entities = self.get_entities(arc_pred, label_pred)

    self.nb_correct += len(true_entities & pred_entities)
    self.nb_pred += len(pred_entities)
    self.nb_true += len(true_entities)
def forward(self, inputs, targets):
    arcs, rels = targets
    arc_scores, rel_scores, seq_lens = inputs

    mask = length_to_mask(seq_lens + 1)
    mask[:, 0] = False  # ignore the first token of each sentence
    arc_scores, rel_scores = arc_scores[mask], rel_scores[mask]

    # the targets carry no bos token, so drop the first column of the mask
    mask = torch.narrow(mask, dim=-1, start=1, length=mask.size(1) - 1)
    arcs, rels = arcs[mask], rels[mask]
    rel_scores = rel_scores[torch.arange(len(arcs)), arcs]

    arc_loss = F.cross_entropy(
        arc_scores, arcs, weight=None,
        ignore_index=self.ignore_index, reduction=self.reduction
    )
    rel_loss = F.cross_entropy(
        rel_scores, rels, weight=None,
        ignore_index=self.ignore_index, reduction=self.reduction
    )
    loss = 2 * ((1 - self.loss_interpolation) * arc_loss +
                self.loss_interpolation * rel_loss)
    return loss
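# The advanced indexing `rel_scores[torch.arange(len(arcs)), arcs]` picks,
# for every dependent token, the label-score vector conditioned on its gold
# head. A standalone toy check of that pattern:
import torch
scores = torch.randn(4, 5, 3)            # [tokens, candidate heads, labels]
heads = torch.tensor([1, 0, 2, 4])       # gold head per token
picked = scores[torch.arange(4), heads]  # [4, 3]
assert torch.equal(picked[0], scores[0, 1])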
def step(self, y_pred, y: dict):
    mask = ~length_to_mask(y['text_length'])
    target = y['word_idn']
    target[mask] = -1  # mark padding positions so they are ignored
    super(Segment, self).step(y_pred, target)
def seg(self,
        inputs: Union[List[str], List[List[str]]],
        truncation: bool = True,
        is_preseged=False):
    """
    Word segmentation.

    Args:
        inputs: list of sentences
        truncation: whether to truncate overlong sentences; if False, an
            exception may be raised
        is_preseged: whether the input is already segmented

    Returns:
        words: the segmented sequences
        hidden: intermediate representations for the other tasks
    """
    # newer transformers releases (>= 3.1) use is_split_into_words
    # in place of is_pretokenized
    if (transformers_version.major, transformers_version.minor) >= (3, 1):
        kwargs = {'is_split_into_words': is_preseged}
    else:
        kwargs = {'is_pretokenized': is_preseged}
    tokenized = self.tokenizer.batch_encode_plus(
        inputs,
        padding=True,
        truncation=truncation,
        return_tensors=self.tensor,
        max_length=self.max_length,
        **kwargs)
    cls, hidden, seg, lengths = self._seg(tokenized, is_preseged=is_preseged)

    # True wherever a subword starts a new word
    batch_prefix = [[
        word_idx != encoding.words[idx - 1]
        for idx, word_idx in enumerate(encoding.words)
        if word_idx is not None
    ] for encoding in tokenized.encodings]

    # merge segments with maximum forward matching
    if self.trie.is_init and not is_preseged:
        matches = self.seg_with_dict(inputs, tokenized, batch_prefix)
        for sent_match, sent_seg in zip(matches, seg):
            for start, end in sent_match:
                sent_seg[start] = self.seg_vocab_dict[WORD_START]
                sent_seg[start + 1:end] = self.seg_vocab_dict[WORD_MIDDLE]
                if end < len(sent_seg):
                    sent_seg[end] = self.seg_vocab_dict[WORD_START]

    if is_preseged:
        sentences = inputs
        word_length = [len(sentence) for sentence in sentences]
        word_idx = []
        for encodings in tokenized.encodings:
            sentence_word_idx = []
            for idx, (start, end) in enumerate(encodings.offsets[1:]):
                if start == 0 and end != 0:
                    sentence_word_idx.append(idx)
            word_idx.append(
                torch.as_tensor(sentence_word_idx, device=self.device))
    else:
        segment_output = convert_idx_to_name(seg, lengths, self.seg_vocab)
        sentences = []
        word_idx = []
        word_length = []

        for source_text, length, encoding, seg_tag, prefix in \
                zip(inputs, lengths, tokenized.encodings, segment_output, batch_prefix):
            offsets = encoding.offsets[1:length + 1]
            text = []
            last_offset = None
            for start, end in offsets:
                text.append('' if last_offset == (start, end) else source_text[start:end])
                last_offset = (start, end)

            for idx in range(1, length):
                current_beg = offsets[idx][0]
                forward_end = offsets[idx - 1][-1]
                if forward_end < current_beg:
                    text[idx] = source_text[forward_end:current_beg] + text[idx]
                if not prefix[idx]:
                    seg_tag[idx] = WORD_MIDDLE

            entities = get_entities(seg_tag)
            word_length.append(len(entities))
            sentences.append([
                ''.join(text[entity[1]:entity[2] + 1]).strip()
                for entity in entities
            ])
            word_idx.append(
                torch.as_tensor([entity[1] for entity in entities],
                                device=self.device))

    word_idx = torch.nn.utils.rnn.pad_sequence(word_idx, batch_first=True)
    word_idx = word_idx.unsqueeze(-1).expand(-1, -1, hidden.shape[-1])  # expand
    word_input = torch.gather(hidden, dim=1, index=word_idx)  # vector of each word's first char

    if len(self.dep_vocab) + len(self.sdp_vocab) > 0:
        word_cls_input = torch.cat([cls, word_input], dim=1)
        word_cls_mask = length_to_mask(
            torch.as_tensor(word_length, device=self.device) + 1)
        word_cls_mask[:, 0] = False
    else:
        word_cls_input, word_cls_mask = None, None

    return sentences, {
        'word_cls': cls,
        'word_input': word_input,
        'word_length': word_length,
        'word_cls_input': word_cls_input,
        'word_cls_mask': word_cls_mask
    }
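# The trie consulted via `self.trie.is_init` holds the user dictionary that
# drives the maximum-forward-matching merge. A hypothetical setup sketch,
# assuming the LTP 4 user-dictionary API (init_dict / add_words; treat the
# names and parameters as assumptions):
#
#   ltp.init_dict(path='user_dict.txt', max_window=4)
#   ltp.add_words(words=['长江大桥'], max_window=4)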
def distill(self, inputs, targets, temperature_calc, distill_loss, gold=None):
    mask = length_to_mask(gold['text_length'])
    logits = inputs[mask]
    logits_T = targets[mask]
    temperature = temperature_calc(logits, logits_T)
    return distill_loss(logits, logits_T, temperature)
def forward(self, inputs, targets):
    mask = length_to_mask(targets['text_length'])
    target = targets['word_idn']
    loss = F.cross_entropy(inputs[mask], target[mask], reduction=self.reduction)
    return loss
def seg(self, inputs: List[str]):
    tokenizerd = self.tokenizer.batch_encode_plus(
        inputs, return_tensors=self.tensor, padding=True)
    cls, hidden, seg, length = self._seg(tokenizerd)

    # merge segments with maximum forward matching
    if self.trie.is_init:
        matches = self.seg_with_dict(inputs, tokenizerd)
        for sent_match, sent_seg in zip(matches, seg):
            for start, end in sent_match:
                sent_seg[start] = 0          # word start (B)
                sent_seg[start + 1:end] = 1  # word middle (I)
                if end < len(sent_seg):
                    sent_seg[end] = 0

    segment_output = convert_idx_to_name(seg, length, self.seg_vocab)

    if USE_PLUGIN:
        offsets = [
            list(filter(lambda x: x != (0, 0), encodings.offsets))
            for encodings in tokenizerd.encodings
        ]
        words = [
            list(filter(lambda x: x is not None, encodings.words))
            for encodings in tokenizerd.encodings
        ]
        sentences, word_idx, word_length = segment_decode(
            inputs, segment_output, offsets, words)
        word_idx = [
            torch.as_tensor(idx, device=self.device) for idx in word_idx
        ]
    else:
        sentences = []
        word_idx = []
        word_length = []

        for source_text, encoding, sentence_seg_tag in zip(
                inputs, tokenizerd.encodings, segment_output):
            text = [
                source_text[start:end]
                for start, end in encoding.offsets[1:-1] if end != 0
            ]

            last_word = 0
            for idx, word in enumerate(encoding.words[1:-1]):
                if word is None or is_chinese_char(text[idx][-1]):
                    continue
                if word != last_word:
                    text[idx] = ' ' + text[idx]
                    last_word = word
                else:
                    sentence_seg_tag[idx] = WORD_MIDDLE

            entities = get_entities(sentence_seg_tag)
            word_length.append(len(entities))
            sentences.append([
                ''.join(text[entity[1]:entity[2] + 1]).strip()
                for entity in entities
            ])
            word_idx.append(
                torch.as_tensor([entity[1] for entity in entities],
                                device=self.device))

    word_idx = torch.nn.utils.rnn.pad_sequence(word_idx, batch_first=True)
    word_idx = word_idx.unsqueeze(-1).expand(-1, -1, hidden.shape[-1])  # expand
    word_input = torch.gather(hidden, dim=1, index=word_idx)  # vector of each word's first char

    word_cls_input = torch.cat([cls, word_input], dim=1)
    word_cls_mask = length_to_mask(
        torch.as_tensor(word_length, device=self.device) + 1)
    word_cls_mask[:, 0] = False  # ignore the first token of each sentence

    return sentences, {
        'word_cls': cls,
        'word_input': word_input,
        'word_length': word_length,
        'word_cls_input': word_cls_input,
        'word_cls_mask': word_cls_mask
    }
def seg(self,
        inputs: Union[List[str], List[List[str]]],
        truncation: bool = True,
        is_preseged=False):
    """
    Word segmentation.

    Args:
        inputs: list of sentences
        truncation: whether to truncate overlong sentences; if False, an
            exception may be raised
        is_preseged: whether the input is already segmented

    Returns:
        words: the segmented sequences
        hidden: intermediate representations for the other tasks
    """
    tokenized = self.tokenizer.batch_encode_plus(
        inputs,
        padding=True,
        truncation=truncation,
        return_tensors=self.tensor,
        max_length=self.max_length,
        is_pretokenized=is_preseged
    )
    cls, hidden, seg, lengths = self._seg(tokenized, is_preseged=is_preseged)

    # merge segments with maximum forward matching
    if self.trie.is_init and not is_preseged:
        matches = self.seg_with_dict(inputs, tokenized)
        for sent_match, sent_seg in zip(matches, seg):
            for start, end in sent_match:
                sent_seg[start] = 0          # word start (B)
                sent_seg[start + 1:end] = 1  # word middle (I)
                if end < len(sent_seg):
                    sent_seg[end] = 0

    if is_preseged:
        sentences = inputs
        word_length = [len(sentence) for sentence in sentences]
        word_idx = []
        for encodings in tokenized.encodings:
            sentence_word_idx = []
            for idx, (start, end) in enumerate(encodings.offsets[1:]):
                if start == 0 and end != 0:
                    sentence_word_idx.append(idx)
            word_idx.append(torch.as_tensor(sentence_word_idx, device=self.device))
    else:
        segment_output = convert_idx_to_name(seg, lengths, self.seg_vocab)
        sentences = []
        word_idx = []
        word_length = []

        for source_text, length, encoding, seg_tag in zip(inputs, lengths,
                                                          tokenized.encodings,
                                                          segment_output):
            words = encoding.words[1:length + 1]
            offsets = encoding.offsets[1:length + 1]
            text = [source_text[start:end] for start, end in offsets]

            for idx in range(1, length):
                current_beg = offsets[idx][0]
                forward_end = offsets[idx - 1][-1]
                if forward_end < current_beg:
                    text[idx] = source_text[forward_end:current_beg] + text[idx]
                if words[idx - 1] == words[idx]:
                    seg_tag[idx] = WORD_MIDDLE

            entities = get_entities(seg_tag)
            word_length.append(len(entities))
            sentences.append([''.join(text[entity[1]:entity[2] + 1]).strip()
                              for entity in entities])
            word_idx.append(torch.as_tensor([entity[1] for entity in entities],
                                            device=self.device))

    word_idx = torch.nn.utils.rnn.pad_sequence(word_idx, batch_first=True)
    word_idx = word_idx.unsqueeze(-1).expand(-1, -1, hidden.shape[-1])  # expand
    word_input = torch.gather(hidden, dim=1, index=word_idx)  # vector of each word's first char

    word_cls_input = torch.cat([cls, word_input], dim=1)
    word_cls_mask = length_to_mask(torch.as_tensor(word_length, device=self.device) + 1)
    word_cls_mask[:, 0] = False  # ignore the first token of each sentence

    return sentences, {
        'word_cls': cls,
        'word_input': word_input,
        'word_length': word_length,
        'word_cls_input': word_cls_input,
        'word_cls_mask': word_cls_mask
    }
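# When the input is already tokenized, is_preseged skips the segmentation
# model and only aligns words to their subword positions. A hypothetical
# call, assuming an instantiated pipeline named `ltp`:
#
#   words, hidden = ltp.seg([['他', '叫', '汤姆', '去', '拿', '外衣', '。']],
#                           is_preseged=True)
#   # words echoes the input; hidden still feeds the downstream decoders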