def __init__(self, word2id, gram2id, feature2id, labelmap, processor, hpara, args):
    """Build the tagging model: a BERT or ZEN encoder plus optional two-channel
    attention (context words + syntactic features) feeding a CRF decoder.

    Args:
        word2id: vocabulary mapping for words (stored, not used directly here).
        gram2id: n-gram/word vocabulary used by the context attention channel.
        feature2id: syntactic-feature vocabulary used by the feature attention channel.
        labelmap: tag -> id mapping; num_labels is len(labelmap) + 1 (0 is padding).
        processor: feature processor object (stored as ``self.feature_processor``).
        hpara: hyper-parameter dict; also used as a cache for tokenizer/config
            so the model can be re-instantiated without the pretrained files.
        args: command-line args; only consulted at construction time
            (``do_train``, ``cache_dir``, ``bert_model``, ``local_rank``).
    """
    super().__init__()
    # Capture constructor arguments so the model can be re-created via from_spec();
    # 'args' is excluded because it is supplied again at reload time.
    self.spec = locals()
    self.spec.pop("self")
    self.spec.pop("__class__")
    self.spec.pop('args')
    self.word2id = word2id
    self.hpara = hpara
    self.max_seq_length = self.hpara['max_seq_length']
    self.max_ngram_size = self.hpara['max_ngram_size']
    self.use_attention = self.hpara['use_attention']
    self.gram2id = gram2id
    self.feature2id = feature2id
    self.feature_processor = processor
    # 'source' / 'feature_flag' are only meaningful when attention is enabled.
    if self.hpara['use_attention']:
        self.source = self.hpara['source']
        self.feature_flag = self.hpara['feature_flag']
    else:
        self.source = None
        self.feature_flag = None
    self.labelmap = labelmap
    # +1 reserves id 0 for padding.
    self.num_labels = len(self.labelmap) + 1
    self.bert_tokenizer = None
    self.bert = None
    self.zen_tokenizer = None
    self.zen = None
    self.zen_ngram_dict = None
    # Exactly one encoder must be selected: BERT or ZEN.
    if self.hpara['use_bert']:
        if args.do_train:
            # Training: load the pretrained model and cache tokenizer/config
            # inside hpara so evaluation-time reloads need no pretrained files.
            cache_dir = args.cache_dir if args.cache_dir else os.path.join(
                str(PYTORCH_PRETRAINED_BERT_CACHE),
                'distributed_{}'.format(args.local_rank))
            self.bert_tokenizer = BertTokenizer.from_pretrained(
                args.bert_model, do_lower_case=self.hpara['do_lower_case'])
            self.bert = BertModel.from_pretrained(args.bert_model, cache_dir=cache_dir)
            self.hpara['bert_tokenizer'] = self.bert_tokenizer
            self.hpara['config'] = self.bert.config
        else:
            # Reload: rebuild the architecture from the cached config; weights
            # are expected to be loaded afterwards via load_state_dict().
            self.bert_tokenizer = self.hpara['bert_tokenizer']
            self.bert = BertModel(self.hpara['config'])
        hidden_size = self.bert.config.hidden_size
        self.dropout = nn.Dropout(self.bert.config.hidden_dropout_prob)
    elif self.hpara['use_zen']:
        if args.do_train:
            cache_dir = args.cache_dir if args.cache_dir else os.path.join(
                str(zen.PYTORCH_PRETRAINED_BERT_CACHE),
                'distributed_{}'.format(args.local_rank))
            self.zen_tokenizer = zen.BertTokenizer.from_pretrained(
                args.bert_model, do_lower_case=self.hpara['do_lower_case'])
            self.zen_ngram_dict = zen.ZenNgramDict(
                args.bert_model, tokenizer=self.zen_tokenizer)
            self.zen = zen.modeling.ZenModel.from_pretrained(
                args.bert_model, cache_dir=cache_dir)
            self.hpara['zen_tokenizer'] = self.zen_tokenizer
            self.hpara['zen_ngram_dict'] = self.zen_ngram_dict
            self.hpara['config'] = self.zen.config
        else:
            self.zen_tokenizer = self.hpara['zen_tokenizer']
            self.zen_ngram_dict = self.hpara['zen_ngram_dict']
            self.zen = zen.modeling.ZenModel(self.hpara['config'])
        hidden_size = self.zen.config.hidden_size
        self.dropout = nn.Dropout(self.zen.config.hidden_dropout_prob)
    else:
        raise ValueError()
    if self.hpara['use_attention']:
        # Two attention channels; classifier sees [encoder; word-att; feature-att],
        # hence 3x hidden size.
        self.context_attention = Attention(hidden_size, len(self.gram2id))
        self.feature_attention = Attention(hidden_size, len(self.feature2id))
        self.classifier = nn.Linear(hidden_size * 3, self.num_labels, bias=False)
    else:
        self.context_attention = None
        self.feature_attention = None
        self.classifier = nn.Linear(hidden_size, self.num_labels, bias=False)
    # tagset_size excludes the special [CLS]/[SEP]/pad labels — TODO confirm
    # the "-3" against the CRF implementation's label layout.
    self.crf = CRF(tagset_size=self.num_labels - 3, gpu=True)
    if args.do_train:
        self.spec['hpara'] = self.hpara
class TwASP(nn.Module):
    """Sequence-tagging model with a BERT/ZEN encoder, optional two-channel
    attention over context words and syntactic features, and a CRF decoder.

    Fixes relative to the previous revision:
      * ``np.int`` / ``np.bool`` (removed in NumPy 1.24) replaced by the
        builtin ``int`` / ``bool`` dtypes — identical behavior, no crash on
        modern NumPy.
      * ``raise KeyError()`` after the diagnostic prints replaced by a bare
        ``raise`` so the original key is preserved in the traceback.
      * ``raise ValueError()`` calls now carry explanatory messages.
    """

    def __init__(self, word2id, gram2id, feature2id, labelmap, processor, hpara, args):
        """Construct the model.

        Args:
            word2id: word vocabulary mapping (stored for later use).
            gram2id: n-gram vocabulary for the context-attention channel.
            feature2id: syntactic-feature vocabulary for the feature-attention channel.
            labelmap: tag -> id mapping; id 0 is reserved for padding.
            processor: feature processor used by ``load_data``.
            hpara: hyper-parameter dict, also used to cache tokenizer/config
                so the model can be rebuilt without pretrained files.
            args: command-line args consulted only at construction time.
        """
        super().__init__()
        # Record constructor arguments so from_spec() can re-create the model;
        # 'args' is excluded because it is supplied again at reload time.
        self.spec = locals()
        self.spec.pop("self")
        self.spec.pop("__class__")
        self.spec.pop('args')
        self.word2id = word2id
        self.hpara = hpara
        self.max_seq_length = self.hpara['max_seq_length']
        self.max_ngram_size = self.hpara['max_ngram_size']
        self.use_attention = self.hpara['use_attention']
        self.gram2id = gram2id
        self.feature2id = feature2id
        self.feature_processor = processor
        # 'source'/'feature_flag' only apply when attention is enabled.
        if self.hpara['use_attention']:
            self.source = self.hpara['source']
            self.feature_flag = self.hpara['feature_flag']
        else:
            self.source = None
            self.feature_flag = None
        self.labelmap = labelmap
        # +1 reserves label id 0 for padding.
        self.num_labels = len(self.labelmap) + 1
        self.bert_tokenizer = None
        self.bert = None
        self.zen_tokenizer = None
        self.zen = None
        self.zen_ngram_dict = None
        if self.hpara['use_bert']:
            if args.do_train:
                # Training: load pretrained weights; cache tokenizer/config in
                # hpara so later reloads need no pretrained files.
                cache_dir = args.cache_dir if args.cache_dir else os.path.join(
                    str(PYTORCH_PRETRAINED_BERT_CACHE),
                    'distributed_{}'.format(args.local_rank))
                self.bert_tokenizer = BertTokenizer.from_pretrained(
                    args.bert_model, do_lower_case=self.hpara['do_lower_case'])
                self.bert = BertModel.from_pretrained(args.bert_model, cache_dir=cache_dir)
                self.hpara['bert_tokenizer'] = self.bert_tokenizer
                self.hpara['config'] = self.bert.config
            else:
                # Reload: architecture from cached config; weights come from
                # load_state_dict() afterwards.
                self.bert_tokenizer = self.hpara['bert_tokenizer']
                self.bert = BertModel(self.hpara['config'])
            hidden_size = self.bert.config.hidden_size
            self.dropout = nn.Dropout(self.bert.config.hidden_dropout_prob)
        elif self.hpara['use_zen']:
            if args.do_train:
                cache_dir = args.cache_dir if args.cache_dir else os.path.join(
                    str(zen.PYTORCH_PRETRAINED_BERT_CACHE),
                    'distributed_{}'.format(args.local_rank))
                self.zen_tokenizer = zen.BertTokenizer.from_pretrained(
                    args.bert_model, do_lower_case=self.hpara['do_lower_case'])
                self.zen_ngram_dict = zen.ZenNgramDict(
                    args.bert_model, tokenizer=self.zen_tokenizer)
                self.zen = zen.modeling.ZenModel.from_pretrained(
                    args.bert_model, cache_dir=cache_dir)
                self.hpara['zen_tokenizer'] = self.zen_tokenizer
                self.hpara['zen_ngram_dict'] = self.zen_ngram_dict
                self.hpara['config'] = self.zen.config
            else:
                self.zen_tokenizer = self.hpara['zen_tokenizer']
                self.zen_ngram_dict = self.hpara['zen_ngram_dict']
                self.zen = zen.modeling.ZenModel(self.hpara['config'])
            hidden_size = self.zen.config.hidden_size
            self.dropout = nn.Dropout(self.zen.config.hidden_dropout_prob)
        else:
            raise ValueError('one of use_bert or use_zen must be enabled')
        if self.hpara['use_attention']:
            # Classifier consumes [encoder; word-attention; feature-attention]
            # concatenated along the hidden dimension, hence hidden_size * 3.
            self.context_attention = Attention(hidden_size, len(self.gram2id))
            self.feature_attention = Attention(hidden_size, len(self.feature2id))
            self.classifier = nn.Linear(hidden_size * 3, self.num_labels, bias=False)
        else:
            self.context_attention = None
            self.feature_attention = None
            self.classifier = nn.Linear(hidden_size, self.num_labels, bias=False)
        # tagset_size excludes special labels — presumably [CLS]/[SEP]/pad;
        # TODO confirm the "-3" against the CRF implementation.
        self.crf = CRF(tagset_size=self.num_labels - 3, gpu=True)
        if args.do_train:
            self.spec['hpara'] = self.hpara

    @staticmethod
    def init_hyper_parameters(args):
        """Build the hyper-parameter dict for this model from parsed args."""
        hyper_parameters = DEFAULT_HPARA.copy()
        hyper_parameters['max_seq_length'] = args.max_seq_length
        hyper_parameters['max_ngram_size'] = args.max_ngram_size
        hyper_parameters['use_bert'] = args.use_bert
        hyper_parameters['use_zen'] = args.use_zen
        hyper_parameters['do_lower_case'] = args.do_lower_case
        hyper_parameters['use_attention'] = args.use_attention
        hyper_parameters['feature_flag'] = args.feature_flag
        hyper_parameters['source'] = args.source
        return hyper_parameters

    def forward(self, input_ids, token_type_ids=None, attention_mask=None,
                labels=None, valid_ids=None, attention_mask_label=None,
                word_seq=None, feature_seq=None, word_matrix=None,
                feature_matrix=None, input_ngram_ids=None,
                ngram_position_matrix=None):
        """Encode, optionally attend, classify, and CRF-decode.

        Returns:
            (total_loss, tag_seq): CRF negative log-likelihood and the
            Viterbi-decoded tag sequence.
        """
        if self.bert is not None:
            sequence_output, _ = self.bert(input_ids, token_type_ids,
                                           attention_mask,
                                           output_all_encoded_layers=False)
        elif self.zen is not None:
            sequence_output, _ = self.zen(
                input_ids,
                input_ngram_ids=input_ngram_ids,
                ngram_position_matrix=ngram_position_matrix,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
                output_all_encoded_layers=False)
        else:
            raise ValueError('no encoder available: neither BERT nor ZEN is set')
        if self.context_attention is not None:
            # Attend over context words and syntactic features, then
            # concatenate both channels with the encoder output.
            word_attention = self.context_attention(word_seq, sequence_output, word_matrix)
            feature_attention = self.feature_attention(feature_seq, sequence_output, feature_matrix)
            conc = torch.cat([sequence_output, word_attention, feature_attention], dim=2)
        else:
            conc = sequence_output
        conc = self.dropout(conc)
        logits = self.classifier(conc)
        total_loss = self.crf.neg_log_likelihood_loss(logits, attention_mask, labels)
        scores, tag_seq = self.crf._viterbi_decode(logits, attention_mask)
        return total_loss, tag_seq

    @property
    def model(self):
        """The serializable model weights (an alias of ``state_dict()``)."""
        return self.state_dict()

    @classmethod
    def from_spec(cls, spec, model, args):
        """Re-create a model from a saved constructor spec and weights."""
        spec = spec.copy()
        res = cls(args=args, **spec)
        res.load_state_dict(model)
        return res

    def load_data(self, data_path, do_predict=False):
        """Read a data file and convert it into ``InputExample`` objects.

        Depending on ``self.feature_flag`` ('pos'/'chunk'/'dep'/None), each
        sentence is paired with context words, syntactic features, and
        (char_index, token_index) matching positions for the attention matrices.
        """
        if do_predict:
            lines = read_sentence(data_path)
        else:
            lines = readfile(data_path)
        # Basename without extension, used as the example-guid prefix.
        flag = data_path[data_path.rfind('/') + 1:data_path.rfind('.')]
        data = []
        if self.feature_flag is None:
            for sentence, label in lines:
                data.append((sentence, label, None, None, None, None))
        elif self.feature_flag == 'pos':
            all_feature_data = self.feature_processor.read_features(data_path, flag=flag)
            for (sentence, label), feature_list in zip(lines, all_feature_data):
                word_list = []
                syn_feature_list = []
                word_matching_position = []
                syn_matching_position = []
                for token_index, token in enumerate(feature_list):
                    current_token_pos = token['pos']
                    current_token = token['word']
                    current_feature = current_token + '_' + current_token_pos
                    # Back off OOV words / features to coarser entries.
                    if current_token not in self.gram2id:
                        current_token = '<UNK>'
                    if current_feature not in self.feature2id:
                        if current_token_pos not in self.feature2id:
                            current_feature = '<UNK>'
                        else:
                            current_feature = current_token_pos
                    word_list.append(current_token)
                    syn_feature_list.append(current_feature)
                    assert current_token in self.gram2id
                    assert current_feature in self.feature2id
                    char_index_list = token['char_index']
                    # Match a window of 2 characters on either side of the token.
                    begin_char_index = max(char_index_list[0] - 2, 0)
                    end_char_index = min(char_index_list[-1] + 3, len(sentence))
                    for i in range(begin_char_index, end_char_index):
                        word_matching_position.append((i, token_index))
                        syn_matching_position.append((i, token_index))
                data.append((sentence, label, word_list, syn_feature_list,
                             word_matching_position, syn_matching_position))
        elif self.feature_flag == 'chunk':
            all_feature_data = self.feature_processor.read_features(data_path, flag=flag)
            for (sentence, label), feature_list in zip(lines, all_feature_data):
                word_list = []
                syn_feature_list = []
                word_matching_position = []
                syn_matching_position = []
                for token_index, token in enumerate(feature_list):
                    # Innermost chunk tag (height 1) for this token.
                    current_token_chunk_tag = token['chunk_tags'][-1]['chunk_tag']
                    assert token['chunk_tags'][-1]['height'] == 1
                    current_token = token['word']
                    current_feature = current_token + '_' + current_token_chunk_tag
                    if current_token not in self.gram2id:
                        current_token = '<UNK>'
                    if current_feature not in self.feature2id:
                        if current_token_chunk_tag not in self.feature2id:
                            current_feature = '<UNK>'
                        else:
                            current_feature = current_token_chunk_tag
                    word_list.append(current_token)
                    syn_feature_list.append(current_feature)
                    assert current_token in self.gram2id
                    assert current_feature in self.feature2id
                    # Every character of the token matches every token inside
                    # its chunk span.
                    token_index_range = token['chunk_tags'][-1]['range']
                    char_index_list = token['char_index']
                    for i in char_index_list:
                        for j in range(token_index_range[0], token_index_range[1]):
                            word_matching_position.append((i, j))
                            syn_matching_position.append((i, j))
                # De-duplicate matching pairs.
                word_matching_position = list(set(word_matching_position))
                syn_matching_position = list(set(syn_matching_position))
                data.append((sentence, label, word_list, syn_feature_list,
                             word_matching_position, syn_matching_position))
        elif self.feature_flag == 'dep':
            all_feature_data = self.feature_processor.read_features(data_path, flag=flag)
            for (sentence, label), feature_list in zip(lines, all_feature_data):
                word_list = []
                syn_feature_list = []
                word_matching_position = []
                syn_matching_position = []
                for token_index, token in enumerate(feature_list):
                    current_token_dep_tag = token['dep']
                    current_token = token['word']
                    current_feature = current_token + '_' + current_token_dep_tag
                    if current_token not in self.gram2id:
                        current_token = '<UNK>'
                    if current_feature not in self.feature2id:
                        if current_token_dep_tag not in self.feature2id:
                            current_feature = '<UNK>'
                        else:
                            current_feature = current_token_dep_tag
                    word_list.append(current_token)
                    syn_feature_list.append(current_feature)
                    assert current_token in self.gram2id
                    assert current_feature in self.feature2id
                    # Characters match both the token itself and its governor
                    # (if any; governed_index < 0 marks the root).
                    if token['governed_index'] < 0:
                        token_index_list = [token_index]
                        char_index_list = token['char_index']
                    else:
                        governed_index = token['governed_index']
                        token_index_list = [token_index, governed_index]
                        governed_token = feature_list[governed_index]
                        char_index_list = token['char_index'] + governed_token['char_index']
                    for i in char_index_list:
                        for j in token_index_list:
                            word_matching_position.append((i, j))
                            syn_matching_position.append((i, j))
                word_matching_position = list(set(word_matching_position))
                syn_matching_position = list(set(syn_matching_position))
                data.append((sentence, label, word_list, syn_feature_list,
                             word_matching_position, syn_matching_position))
        else:
            raise ValueError('unknown feature_flag: %r' % (self.feature_flag,))
        examples = []
        for i, (sentence, label, word_list, syn_feature_list,
                word_matching_position, syn_matching_position) in enumerate(data):
            guid = "%s-%s" % (flag, i)
            text_a = ' '.join(sentence)
            text_b = None
            if word_list is not None:
                word = ' '.join(word_list)
                word_list_len = len(word_list)
            else:
                word = None
                word_list_len = 0
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b,
                             label=label, word=word,
                             syn_feature=syn_feature_list,
                             word_matrix=word_matching_position,
                             syn_matrix=syn_matching_position,
                             sent_len=len(sentence),
                             word_list_len=word_list_len))
        return examples

    def convert_examples_to_features(self, examples):
        """Turn ``InputExample``s into padded, tensor-ready ``InputFeatures``.

        Sequence length is clipped to ``self.max_seq_length`` but shrunk to
        ~1.1x the longest sentence in the batch to save padding.
        """
        max_seq_length = min(
            int(max([e.sent_len for e in examples]) * 1.1 + 2),
            self.max_seq_length)
        if self.use_attention:
            # At least 1 to keep the attention matrices non-degenerate.
            max_ngram_size = max(
                min(max([e.word_list_len for e in examples]),
                    self.max_ngram_size), 1)
        features = []
        tokenizer = self.bert_tokenizer if self.bert_tokenizer is not None else self.zen_tokenizer
        for (ex_index, example) in enumerate(examples):
            textlist = example.text_a.split(' ')
            labellist = example.label
            tokens = []
            labels = []
            valid = []
            label_mask = []
            for i, word in enumerate(textlist):
                token = tokenizer.tokenize(word)
                tokens.extend(token)
                label_1 = labellist[i]
                # Only the first word-piece of each word carries the label.
                for m in range(len(token)):
                    if m == 0:
                        labels.append(label_1)
                        valid.append(1)
                        label_mask.append(1)
                    else:
                        valid.append(0)
            # Truncate, leaving room for [CLS] and [SEP].
            if len(tokens) >= max_seq_length - 1:
                tokens = tokens[0:(max_seq_length - 2)]
                labels = labels[0:(max_seq_length - 2)]
                valid = valid[0:(max_seq_length - 2)]
                label_mask = label_mask[0:(max_seq_length - 2)]
            ntokens = []
            segment_ids = []
            label_ids = []
            ntokens.append("[CLS]")
            segment_ids.append(0)
            valid.insert(0, 1)
            label_mask.insert(0, 1)
            label_ids.append(self.labelmap["[CLS]"])
            for i, token in enumerate(tokens):
                ntokens.append(token)
                segment_ids.append(0)
                if len(labels) > i:
                    if labels[i] in self.labelmap:
                        label_ids.append(self.labelmap[labels[i]])
                    else:
                        label_ids.append(self.labelmap['<UNK>'])
            ntokens.append("[SEP]")
            segment_ids.append(0)
            valid.append(1)
            label_mask.append(1)
            label_ids.append(self.labelmap["[SEP]"])
            input_ids = tokenizer.convert_tokens_to_ids(ntokens)
            input_mask = [1] * len(input_ids)
            # NOTE(review): this overwrites the per-wordpiece label_mask built
            # above with an all-ones mask over label_ids — kept as-is.
            label_mask = [1] * len(label_ids)
            while len(input_ids) < max_seq_length:
                input_ids.append(0)
                input_mask.append(0)
                segment_ids.append(0)
                label_ids.append(0)
                valid.append(1)
                label_mask.append(0)
            while len(label_ids) < max_seq_length:
                label_ids.append(0)
                label_mask.append(0)
            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length
            assert len(label_ids) == max_seq_length
            assert len(valid) == max_seq_length
            assert len(label_mask) == max_seq_length
            if self.use_attention:
                wordlist = example.word
                wordlist = wordlist.split(' ') if len(wordlist) > 0 else []
                syn_features = example.syn_feature
                word_matching_position = example.word_matrix
                syn_matching_position = example.syn_matrix
                word_ids = []
                feature_ids = []
                # `int`/`bool` builtins: np.int / np.bool were removed in NumPy 1.24.
                word_matching_matrix = np.zeros(
                    (max_seq_length, max_ngram_size), dtype=int)
                syn_matching_matrix = np.zeros(
                    (max_seq_length, max_ngram_size), dtype=int)
                if len(wordlist) > max_ngram_size:
                    wordlist = wordlist[:max_ngram_size]
                    syn_features = syn_features[:max_ngram_size]
                for word in wordlist:
                    if word == '':
                        continue
                    try:
                        word_ids.append(self.gram2id[word])
                    except KeyError:
                        print(word)
                        print(wordlist)
                        print(textlist)
                        # Re-raise the original KeyError (keeps the key).
                        raise
                for feature in syn_features:
                    feature_ids.append(self.feature2id[feature])
                while len(word_ids) < max_ngram_size:
                    word_ids.append(0)
                    feature_ids.append(0)
                for position in word_matching_position:
                    # +1 accounts for the prepended [CLS] token.
                    char_p = position[0] + 1
                    word_p = position[1]
                    if char_p > max_seq_length - 2 or word_p > max_ngram_size - 1:
                        continue
                    else:
                        word_matching_matrix[char_p][word_p] = 1
                for position in syn_matching_position:
                    char_p = position[0] + 1
                    word_p = position[1]
                    if char_p > max_seq_length - 2 or word_p > max_ngram_size - 1:
                        continue
                    else:
                        syn_matching_matrix[char_p][word_p] = 1
                assert len(word_ids) == max_ngram_size
                assert len(feature_ids) == max_ngram_size
            else:
                word_ids = None
                feature_ids = None
                word_matching_matrix = None
                syn_matching_matrix = None
            if self.zen_ngram_dict is not None:
                ngram_matches = []
                # Filter the ngram segment from 2 to 7 to check whether there is a ngram
                for p in range(2, 8):
                    for q in range(0, len(tokens) - p + 1):
                        character_segment = tokens[q:q + p]
                        # q is the starting position of the ngram; p its length.
                        character_segment = tuple(character_segment)
                        if character_segment in self.zen_ngram_dict.ngram_to_id_dict:
                            ngram_index = self.zen_ngram_dict.ngram_to_id_dict[character_segment]
                            ngram_matches.append(
                                [ngram_index, q, p, character_segment])
                random.shuffle(ngram_matches)
                max_ngram_in_seq_proportion = math.ceil(
                    (len(tokens) / max_seq_length) * self.zen_ngram_dict.max_ngram_in_seq)
                if len(ngram_matches) > max_ngram_in_seq_proportion:
                    ngram_matches = ngram_matches[:max_ngram_in_seq_proportion]
                ngram_ids = [ngram[0] for ngram in ngram_matches]
                ngram_positions = [ngram[1] for ngram in ngram_matches]
                ngram_lengths = [ngram[2] for ngram in ngram_matches]
                ngram_tuples = [ngram[3] for ngram in ngram_matches]
                ngram_seg_ids = [
                    0 if position < (len(tokens) + 2) else 1
                    for position in ngram_positions
                ]
                ngram_mask_array = np.zeros(
                    self.zen_ngram_dict.max_ngram_in_seq, dtype=bool)
                ngram_mask_array[:len(ngram_ids)] = 1
                # record the masked positions
                ngram_positions_matrix = np.zeros(
                    shape=(max_seq_length, self.zen_ngram_dict.max_ngram_in_seq),
                    dtype=np.int32)
                for i in range(len(ngram_ids)):
                    ngram_positions_matrix[
                        ngram_positions[i]:ngram_positions[i] + ngram_lengths[i],
                        i] = 1.0
                # Zero-pad up to the max ngram in seq length.
                padding = [0] * (self.zen_ngram_dict.max_ngram_in_seq - len(ngram_ids))
                ngram_ids += padding
                ngram_lengths += padding
                ngram_seg_ids += padding
            else:
                ngram_ids = None
                ngram_positions_matrix = None
                ngram_lengths = None
                ngram_tuples = None
                ngram_seg_ids = None
                ngram_mask_array = None
            features.append(
                InputFeatures(input_ids=input_ids,
                              input_mask=input_mask,
                              segment_ids=segment_ids,
                              label_id=label_ids,
                              valid_ids=valid,
                              label_mask=label_mask,
                              word_ids=word_ids,
                              syn_feature_ids=feature_ids,
                              word_matching_matrix=word_matching_matrix,
                              syn_matching_matrix=syn_matching_matrix,
                              ngram_ids=ngram_ids,
                              ngram_positions=ngram_positions_matrix,
                              ngram_lengths=ngram_lengths,
                              ngram_tuples=ngram_tuples,
                              ngram_seg_ids=ngram_seg_ids,
                              ngram_masks=ngram_mask_array))
        return features

    def feature2input(self, device, feature):
        """Stack a batch of ``InputFeatures`` into tensors on ``device``."""
        all_input_ids = torch.tensor([f.input_ids for f in feature], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in feature], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in feature], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in feature], dtype=torch.long)
        all_valid_ids = torch.tensor([f.valid_ids for f in feature], dtype=torch.long)
        all_lmask_ids = torch.tensor([f.label_mask for f in feature], dtype=torch.long)
        input_ids = all_input_ids.to(device)
        input_mask = all_input_mask.to(device)
        segment_ids = all_segment_ids.to(device)
        label_ids = all_label_ids.to(device)
        valid_ids = all_valid_ids.to(device)
        l_mask = all_lmask_ids.to(device)
        if self.hpara['use_attention']:
            all_word_ids = torch.tensor([f.word_ids for f in feature], dtype=torch.long)
            all_feature_ids = torch.tensor(
                [f.syn_feature_ids for f in feature], dtype=torch.long)
            all_word_matching_matrix = torch.tensor(
                [f.word_matching_matrix for f in feature], dtype=torch.float)
            word_ids = all_word_ids.to(device)
            feature_ids = all_feature_ids.to(device)
            word_matching_matrix = all_word_matching_matrix.to(device)
        else:
            word_ids = None
            feature_ids = None
            word_matching_matrix = None
        if self.hpara['use_zen']:
            all_ngram_ids = torch.tensor([f.ngram_ids for f in feature], dtype=torch.long)
            all_ngram_positions = torch.tensor(
                [f.ngram_positions for f in feature], dtype=torch.long)
            # all_ngram_lengths = torch.tensor([f.ngram_lengths for f in train_features], dtype=torch.long)
            # all_ngram_seg_ids = torch.tensor([f.ngram_seg_ids for f in train_features], dtype=torch.long)
            # all_ngram_masks = torch.tensor([f.ngram_masks for f in train_features], dtype=torch.long)
            ngram_ids = all_ngram_ids.to(device)
            ngram_positions = all_ngram_positions.to(device)
        else:
            ngram_ids = None
            ngram_positions = None
        return feature_ids, input_ids, input_mask, l_mask, label_ids, ngram_ids, ngram_positions, segment_ids, valid_ids, word_ids, word_matching_matrix
def __init__(self, word2id, label2id, hpara, model_path, department2id=None, disease2id=None):
    """Build a dialogue/utterance classifier: a BERT or ZEN utterance encoder
    with optional memory, party/department/disease embeddings, an optional
    LSTM utterance-level encoder, and a softmax or CRF decoder.

    Args:
        word2id: word vocabulary; sizes the optional Memory module.
        label2id: label -> id mapping; num_labels is len(label2id).
        hpara: hyper-parameter dict controlling all optional components.
        model_path: path/name of the pretrained BERT or ZEN model.
        department2id: department vocabulary (required iff hpara['use_department']).
        disease2id: disease vocabulary (required iff hpara['use_disease']).
    """
    super().__init__()
    self.word2id = word2id
    self.department2id = None
    self.disease2id = None
    self.label2id = label2id
    self.party2id = None
    self.hpara = hpara
    self.num_labels = len(self.label2id)
    self.max_seq_length = self.hpara['max_seq_length']
    self.use_memory = self.hpara['use_memory']
    self.use_department = self.hpara['use_department']
    self.use_party = self.hpara['use_party']
    self.use_disease = self.hpara['use_disease']
    self.decoder = self.hpara['decoder']
    self.lstm_hidden_size = self.hpara['lstm_hidden_size']
    self.max_dialog_length = self.hpara['max_dialog_length']
    self.bert_tokenizer = None
    self.bert = None
    self.zen_tokenizer = None
    self.zen = None
    self.zen_ngram_dict = None
    # Exactly one encoder must be selected: BERT or ZEN.
    if self.hpara['use_bert']:
        self.bert_tokenizer = BertTokenizer.from_pretrained(
            model_path, do_lower_case=self.hpara['do_lower_case'])
        self.bert = BertModel.from_pretrained(model_path, cache_dir='')
        hidden_size = self.bert.config.hidden_size
        self.dropout = nn.Dropout(self.bert.config.hidden_dropout_prob)
    elif self.hpara['use_zen']:
        self.zen_tokenizer = zen.BertTokenizer.from_pretrained(
            model_path, do_lower_case=self.hpara['do_lower_case'])
        self.zen_ngram_dict = zen.ZenNgramDict(
            model_path, tokenizer=self.zen_tokenizer)
        self.zen = zen.modeling.ZenModel.from_pretrained(model_path, cache_dir='')
        hidden_size = self.zen.config.hidden_size
        self.dropout = nn.Dropout(self.zen.config.hidden_dropout_prob)
    else:
        raise ValueError()
    # Keep the raw encoder width; hidden_size grows as components are added.
    ori_hidden_size = hidden_size
    if self.use_memory:
        # Memory output is concatenated with the encoder output (hence * 2).
        self.memory = Memory(hidden_size, len(word2id))
        hidden_size = hidden_size * 2
    else:
        self.memory = None
    if self.use_party:
        # Party embedding is concatenated as well. NOTE(review): the embedding
        # has 5 rows but party2id defines only 4 ids — presumably one spare row;
        # confirm against the caller.
        self.party_embedding = nn.Embedding(5, ori_hidden_size)
        hidden_size += ori_hidden_size
        self.party2id = {'<PAD>': 0, '<UNK>': 1, 'P': 2, 'D': 3}
    else:
        self.party_embedding = None
    utterance_hidden_size = hidden_size
    # Optional utterance-level sequence encoder over the dialogue.
    if self.hpara['utterance_encoder'] == 'LSTM':
        self.utterance_encoder = nn.LSTM(input_size=hidden_size,
                                         hidden_size=self.lstm_hidden_size,
                                         bidirectional=False,
                                         batch_first=True)
        utterance_hidden_size = self.lstm_hidden_size
    elif self.hpara['utterance_encoder'] == 'biLSTM':
        self.utterance_encoder = nn.LSTM(input_size=hidden_size,
                                         hidden_size=self.lstm_hidden_size,
                                         bidirectional=True,
                                         batch_first=True)
        utterance_hidden_size = self.lstm_hidden_size * 2
    else:
        self.utterance_encoder = None
    if self.use_department:
        self.department_embedding = nn.Embedding(len(department2id),
                                                 utterance_hidden_size)
        self.department2id = department2id
    else:
        self.department_embedding = None
    if self.use_disease:
        self.disease_embedding = nn.Embedding(len(disease2id),
                                              utterance_hidden_size)
        self.disease2id = disease2id
    else:
        self.disease_embedding = None
    # Classifier input width grows once per enabled auxiliary embedding.
    if self.use_department and self.use_disease:
        utterance_hidden_size = utterance_hidden_size * 3
    elif self.use_department or self.use_disease:
        utterance_hidden_size = utterance_hidden_size * 2
    self.classifier = nn.Linear(utterance_hidden_size, self.num_labels)
    # Decoder: plain cross-entropy (ignoring pad id 0) or a CRF layer.
    if self.decoder == 'softmax':
        self.loss_fct = CrossEntropyLoss(ignore_index=0)
    elif self.decoder == 'crf':
        self.crf = CRF(self.num_labels, batch_first=True)
    else:
        raise ValueError()
class WMSeg(nn.Module):
    """Word segmenter with a BERT/ZEN encoder, an optional word key-value
    memory (``WordKVMN``) over in-vocabulary n-grams, and a CRF or softmax
    decoder.

    Fixes relative to the previous revision:
      * ``np.int`` / ``np.bool`` (removed in NumPy 1.24) replaced by the
        builtin ``int`` / ``bool`` dtypes — identical behavior, no crash on
        modern NumPy.
      * ``raise KeyError()`` after the diagnostic prints replaced by a bare
        ``raise`` so the original key is preserved in the traceback.
      * Ambiguous local ``l`` renamed to ``seg_tag`` (PEP 8 / E741).
    """

    def __init__(self, word2id, gram2id, labelmap, hpara, args):
        """Construct the model.

        Args:
            word2id: word vocabulary mapping (stored for later use).
            gram2id: n-gram lexicon used by the KV memory and load_data matching.
            labelmap: tag -> id mapping; id 0 is reserved for padding.
            hpara: hyper-parameter dict, also used to cache tokenizer/config.
            args: command-line args consulted only at construction time.
        """
        super().__init__()
        # Record constructor arguments so from_spec() can re-create the model;
        # 'args' is excluded because it is supplied again at reload time.
        self.spec = locals()
        self.spec.pop("self")
        self.spec.pop("__class__")
        self.spec.pop('args')
        self.word2id = word2id
        self.gram2id = gram2id
        self.labelmap = labelmap
        self.hpara = hpara
        # +1 reserves label id 0 for padding.
        self.num_labels = len(self.labelmap) + 1
        self.max_seq_length = self.hpara['max_seq_length']
        self.max_ngram_size = self.hpara['max_ngram_size']
        self.max_ngram_length = self.hpara['max_ngram_length']
        self.bert_tokenizer = None
        self.bert = None
        self.zen_tokenizer = None
        self.zen = None
        self.zen_ngram_dict = None
        if self.hpara['use_bert']:
            if args.do_train:
                # Training: load pretrained weights; cache tokenizer/config in
                # hpara so later reloads need no pretrained files.
                cache_dir = args.cache_dir if args.cache_dir else os.path.join(
                    str(PYTORCH_PRETRAINED_BERT_CACHE),
                    'distributed_{}'.format(args.local_rank))
                self.bert_tokenizer = BertTokenizer.from_pretrained(
                    args.bert_model, do_lower_case=self.hpara['do_lower_case'])
                self.bert = BertModel.from_pretrained(args.bert_model, cache_dir=cache_dir)
                self.hpara['bert_tokenizer'] = self.bert_tokenizer
                self.hpara['config'] = self.bert.config
            else:
                # Reload: architecture from cached config; weights come from
                # load_state_dict() afterwards.
                self.bert_tokenizer = self.hpara['bert_tokenizer']
                self.bert = BertModel(self.hpara['config'])
            hidden_size = self.bert.config.hidden_size
            self.dropout = nn.Dropout(self.bert.config.hidden_dropout_prob)
        elif self.hpara['use_zen']:
            if args.do_train:
                cache_dir = args.cache_dir if args.cache_dir else os.path.join(
                    str(zen.PYTORCH_PRETRAINED_BERT_CACHE),
                    'distributed_{}'.format(args.local_rank))
                self.zen_tokenizer = zen.BertTokenizer.from_pretrained(
                    args.bert_model, do_lower_case=self.hpara['do_lower_case'])
                self.zen_ngram_dict = zen.ZenNgramDict(
                    args.bert_model, tokenizer=self.zen_tokenizer)
                self.zen = zen.modeling.ZenModel.from_pretrained(
                    args.bert_model, cache_dir=cache_dir)
                self.hpara['zen_tokenizer'] = self.zen_tokenizer
                self.hpara['zen_ngram_dict'] = self.zen_ngram_dict
                self.hpara['config'] = self.zen.config
            else:
                self.zen_tokenizer = self.hpara['zen_tokenizer']
                self.zen_ngram_dict = self.hpara['zen_ngram_dict']
                self.zen = zen.modeling.ZenModel(self.hpara['config'])
            hidden_size = self.zen.config.hidden_size
            self.dropout = nn.Dropout(self.zen.config.hidden_dropout_prob)
        else:
            raise ValueError('one of use_bert or use_zen must be enabled')
        if self.hpara['use_memory']:
            self.kv_memory = WordKVMN(hidden_size, len(gram2id))
        else:
            self.kv_memory = None
        self.classifier = nn.Linear(hidden_size, self.num_labels, bias=False)
        if self.hpara['decoder'] == 'crf':
            # tagset_size excludes special labels — presumably [CLS]/[SEP]/pad;
            # TODO confirm the "-3" against the CRF implementation.
            self.crf = CRF(tagset_size=self.num_labels - 3, gpu=True)
        else:
            self.crf = None
        if args.do_train:
            self.spec['hpara'] = self.hpara

    def forward(self, input_ids, token_type_ids=None, attention_mask=None,
                labels=None, valid_ids=None, attention_mask_label=None,
                word_seq=None, label_value_matrix=None, word_mask=None,
                input_ngram_ids=None, ngram_position_matrix=None):
        """Encode, optionally apply the KV memory, classify, and decode.

        Returns:
            (total_loss, tag_seq): loss and the decoded tag sequence
            (Viterbi path for CRF, argmax for softmax).
        """
        if self.bert is not None:
            sequence_output, _ = self.bert(input_ids, token_type_ids,
                                           attention_mask,
                                           output_all_encoded_layers=False)
        elif self.zen is not None:
            sequence_output, _ = self.zen(
                input_ids,
                input_ngram_ids=input_ngram_ids,
                ngram_position_matrix=ngram_position_matrix,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
                output_all_encoded_layers=False)
        else:
            raise ValueError('no encoder available: neither BERT nor ZEN is set')
        if self.kv_memory is not None:
            sequence_output = self.kv_memory(word_seq, sequence_output,
                                             label_value_matrix, word_mask)
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)
        if self.crf is not None:
            # crf = CRF(tagset_size=number_of_labels+1, gpu=True)
            total_loss = self.crf.neg_log_likelihood_loss(
                logits, attention_mask, labels)
            scores, tag_seq = self.crf._viterbi_decode(logits, attention_mask)
            # Only keep active parts of the loss
        else:
            loss_fct = CrossEntropyLoss(ignore_index=0)
            total_loss = loss_fct(logits.view(-1, self.num_labels),
                                  labels.view(-1))
            tag_seq = torch.argmax(F.log_softmax(logits, dim=2), dim=2)
        return total_loss, tag_seq

    @staticmethod
    def init_hyper_parameters(args):
        """Build the hyper-parameter dict for this model from parsed args."""
        hyper_parameters = DEFAULT_HPARA.copy()
        hyper_parameters['max_seq_length'] = args.max_seq_length
        hyper_parameters['max_ngram_size'] = args.max_ngram_size
        hyper_parameters['max_ngram_length'] = args.max_ngram_length
        hyper_parameters['use_bert'] = args.use_bert
        hyper_parameters['use_zen'] = args.use_zen
        hyper_parameters['do_lower_case'] = args.do_lower_case
        hyper_parameters['use_memory'] = args.use_memory
        hyper_parameters['decoder'] = args.decoder
        return hyper_parameters

    @property
    def model(self):
        """The serializable model weights (an alias of ``state_dict()``)."""
        return self.state_dict()

    @classmethod
    def from_spec(cls, spec, model, args):
        """Re-create a model from a saved constructor spec and weights."""
        spec = spec.copy()
        res = cls(args=args, **spec)
        res.load_state_dict(model)
        return res

    def load_data(self, data_path, do_predict=False):
        """Read a data file and convert it into ``InputExample`` objects.

        When the KV memory is enabled, every lexicon n-gram (up to
        ``max_ngram_length`` characters) found in a sentence is recorded along
        with per-character segmentation tags (B/I/E/S) as matching positions.
        """
        if not do_predict:
            # Basename without extension is used as the example-guid prefix.
            flag = data_path[data_path.rfind('/') + 1:data_path.rfind('.')]
            lines = readfile(data_path, flag=flag)
        else:
            flag = 'predict'
            lines = readsentence(data_path)
        data = []
        for sentence, label in lines:
            if self.kv_memory is not None:
                word_list = []
                matching_position = []
                for i in range(len(sentence)):
                    for j in range(self.max_ngram_length):
                        # NOTE(review): this bound admits i + j == len(sentence),
                        # where sentence[i:i + j + 1] is one char shorter than
                        # j + 1 — kept as-is to preserve original behavior.
                        if i + j > len(sentence):
                            break
                        word = ''.join(sentence[i:i + j + 1])
                        if word in self.gram2id:
                            # Reuse the index if this n-gram was seen before
                            # in the sentence.
                            try:
                                index = word_list.index(word)
                            except ValueError:
                                word_list.append(word)
                                index = len(word_list) - 1
                            word_len = len(word)
                            # Tag each covered character with its position in
                            # the n-gram: Single, Begin, End, or Inside.
                            for k in range(j + 1):
                                if word_len == 1:
                                    seg_tag = 'S'
                                elif k == 0:
                                    seg_tag = 'B'
                                elif k == j:
                                    seg_tag = 'E'
                                else:
                                    seg_tag = 'I'
                                matching_position.append((i + k, index, seg_tag))
            else:
                word_list = None
                matching_position = None
            data.append((sentence, label, word_list, matching_position))
        examples = []
        for i, (sentence, label, word_list, matching_position) in enumerate(data):
            guid = "%s-%s" % (flag, i)
            text_a = ' '.join(sentence)
            text_b = None
            if word_list is not None:
                word = ' '.join(word_list)
            else:
                word = None
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b,
                             label=label, word=word, matrix=matching_position))
        return examples

    def convert_examples_to_features(self, examples):
        """Turn ``InputExample``s into padded, tensor-ready ``InputFeatures``.

        Sequence length is clipped to ``self.max_seq_length`` but shrunk to
        ~1.1x the longest sentence in the batch to save padding.
        """
        max_seq_length = min(
            int(max([len(e.text_a.split(' ')) for e in examples]) * 1.1 + 2),
            self.max_seq_length)
        if self.kv_memory is not None:
            # At least 1 to keep the matching matrix non-degenerate.
            max_word_size = max(
                min(max([len(e.word.split(' ')) for e in examples]),
                    self.max_ngram_size), 1)
        features = []
        tokenizer = self.bert_tokenizer if self.bert_tokenizer is not None else self.zen_tokenizer
        for (ex_index, example) in enumerate(examples):
            textlist = example.text_a.split(' ')
            labellist = example.label
            tokens = []
            labels = []
            valid = []
            label_mask = []
            for i, word in enumerate(textlist):
                token = tokenizer.tokenize(word)
                tokens.extend(token)
                label_1 = labellist[i]
                # Only the first word-piece of each word carries the label.
                for m in range(len(token)):
                    if m == 0:
                        valid.append(1)
                        labels.append(label_1)
                        label_mask.append(1)
                    else:
                        valid.append(0)
            # Truncate, leaving room for [CLS] and [SEP].
            if len(tokens) >= max_seq_length - 1:
                tokens = tokens[0:(max_seq_length - 2)]
                labels = labels[0:(max_seq_length - 2)]
                valid = valid[0:(max_seq_length - 2)]
                label_mask = label_mask[0:(max_seq_length - 2)]
            ntokens = []
            segment_ids = []
            label_ids = []
            ntokens.append("[CLS]")
            segment_ids.append(0)
            valid.insert(0, 1)
            label_mask.insert(0, 1)
            label_ids.append(self.labelmap["[CLS]"])
            for i, token in enumerate(tokens):
                ntokens.append(token)
                segment_ids.append(0)
                if len(labels) > i:
                    label_ids.append(self.labelmap[labels[i]])
            ntokens.append("[SEP]")
            segment_ids.append(0)
            valid.append(1)
            label_mask.append(1)
            label_ids.append(self.labelmap["[SEP]"])
            input_ids = tokenizer.convert_tokens_to_ids(ntokens)
            input_mask = [1] * len(input_ids)
            # NOTE(review): this overwrites the per-wordpiece label_mask built
            # above with an all-ones mask over label_ids — kept as-is.
            label_mask = [1] * len(label_ids)
            while len(input_ids) < max_seq_length:
                input_ids.append(0)
                input_mask.append(0)
                segment_ids.append(0)
                label_ids.append(0)
                valid.append(1)
                label_mask.append(0)
            while len(label_ids) < max_seq_length:
                label_ids.append(0)
                label_mask.append(0)
            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length
            assert len(label_ids) == max_seq_length
            assert len(valid) == max_seq_length
            assert len(label_mask) == max_seq_length
            if self.kv_memory is not None:
                wordlist = example.word
                wordlist = wordlist.split(' ') if len(wordlist) > 0 else []
                matching_position = example.matrix
                word_ids = []
                # `int` builtin: np.int was removed in NumPy 1.24.
                matching_matrix = np.zeros((max_seq_length, max_word_size),
                                           dtype=int)
                if len(wordlist) > max_word_size:
                    wordlist = wordlist[:max_word_size]
                for word in wordlist:
                    try:
                        word_ids.append(self.gram2id[word])
                    except KeyError:
                        print(word)
                        print(wordlist)
                        print(textlist)
                        # Re-raise the original KeyError (keeps the key).
                        raise
                while len(word_ids) < max_word_size:
                    word_ids.append(0)
                for position in matching_position:
                    # +1 accounts for the prepended [CLS] token.
                    char_p = position[0] + 1
                    word_p = position[1]
                    if char_p > max_seq_length - 2 or word_p > max_word_size - 1:
                        continue
                    else:
                        # Cell holds the id of the B/I/E/S tag for this
                        # (character, n-gram) pair.
                        matching_matrix[char_p][word_p] = self.labelmap[position[2]]
                assert len(word_ids) == max_word_size
            else:
                word_ids = None
                matching_matrix = None
            if self.zen_ngram_dict is not None:
                ngram_matches = []
                # Filter the ngram segment from 2 to 7 to check whether there is a ngram
                for p in range(2, 8):
                    for q in range(0, len(tokens) - p + 1):
                        character_segment = tokens[q:q + p]
                        # q is the starting position of the ngram; p its length.
                        character_segment = tuple(character_segment)
                        if character_segment in self.zen_ngram_dict.ngram_to_id_dict:
                            ngram_index = self.zen_ngram_dict.ngram_to_id_dict[character_segment]
                            ngram_matches.append(
                                [ngram_index, q, p, character_segment])
                # random.shuffle(ngram_matches)
                ngram_matches = sorted(ngram_matches, key=lambda s: s[0])
                max_ngram_in_seq_proportion = math.ceil(
                    (len(tokens) / max_seq_length) * self.zen_ngram_dict.max_ngram_in_seq)
                if len(ngram_matches) > max_ngram_in_seq_proportion:
                    ngram_matches = ngram_matches[:max_ngram_in_seq_proportion]
                ngram_ids = [ngram[0] for ngram in ngram_matches]
                ngram_positions = [ngram[1] for ngram in ngram_matches]
                ngram_lengths = [ngram[2] for ngram in ngram_matches]
                ngram_tuples = [ngram[3] for ngram in ngram_matches]
                ngram_seg_ids = [
                    0 if position < (len(tokens) + 2) else 1
                    for position in ngram_positions
                ]
                # `bool` builtin: np.bool was removed in NumPy 1.24.
                ngram_mask_array = np.zeros(
                    self.zen_ngram_dict.max_ngram_in_seq, dtype=bool)
                ngram_mask_array[:len(ngram_ids)] = 1
                # record the masked positions
                ngram_positions_matrix = np.zeros(
                    shape=(max_seq_length, self.zen_ngram_dict.max_ngram_in_seq),
                    dtype=np.int32)
                for i in range(len(ngram_ids)):
                    ngram_positions_matrix[
                        ngram_positions[i]:ngram_positions[i] + ngram_lengths[i],
                        i] = 1.0
                # Zero-pad up to the max ngram in seq length.
                padding = [0] * (self.zen_ngram_dict.max_ngram_in_seq - len(ngram_ids))
                ngram_ids += padding
                ngram_lengths += padding
                ngram_seg_ids += padding
            else:
                ngram_ids = None
                ngram_positions_matrix = None
                ngram_lengths = None
                ngram_tuples = None
                ngram_seg_ids = None
                ngram_mask_array = None
            features.append(
                InputFeatures(input_ids=input_ids,
                              input_mask=input_mask,
                              segment_ids=segment_ids,
                              label_id=label_ids,
                              valid_ids=valid,
                              label_mask=label_mask,
                              word_ids=word_ids,
                              matching_matrix=matching_matrix,
                              ngram_ids=ngram_ids,
                              ngram_positions=ngram_positions_matrix,
                              ngram_lengths=ngram_lengths,
                              ngram_tuples=ngram_tuples,
                              ngram_seg_ids=ngram_seg_ids,
                              ngram_masks=ngram_mask_array))
        return features

    def feature2input(self, device, feature):
        """Stack a batch of ``InputFeatures`` into tensors on ``device``."""
        all_input_ids = torch.tensor([f.input_ids for f in feature], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in feature], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in feature], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in feature], dtype=torch.long)
        all_valid_ids = torch.tensor([f.valid_ids for f in feature], dtype=torch.long)
        all_lmask_ids = torch.tensor([f.label_mask for f in feature], dtype=torch.long)
        input_ids = all_input_ids.to(device)
        input_mask = all_input_mask.to(device)
        segment_ids = all_segment_ids.to(device)
        label_ids = all_label_ids.to(device)
        valid_ids = all_valid_ids.to(device)
        l_mask = all_lmask_ids.to(device)
        if self.hpara['use_memory']:
            all_word_ids = torch.tensor([f.word_ids for f in feature], dtype=torch.long)
            all_matching_matrix = torch.tensor(
                [f.matching_matrix for f in feature], dtype=torch.long)
            # NOTE(review): word_mask is built from matching_matrix (as float),
            # not from a dedicated mask field — looks intentional (non-zero
            # cells mark valid positions), but confirm against WordKVMN.
            all_word_mask = torch.tensor([f.matching_matrix for f in feature],
                                         dtype=torch.float)
            word_ids = all_word_ids.to(device)
            matching_matrix = all_matching_matrix.to(device)
            word_mask = all_word_mask.to(device)
        else:
            word_ids = None
            matching_matrix = None
            word_mask = None
        if self.hpara['use_zen']:
            all_ngram_ids = torch.tensor([f.ngram_ids for f in feature], dtype=torch.long)
            all_ngram_positions = torch.tensor(
                [f.ngram_positions for f in feature], dtype=torch.long)
            # all_ngram_lengths = torch.tensor([f.ngram_lengths for f in train_features], dtype=torch.long)
            # all_ngram_seg_ids = torch.tensor([f.ngram_seg_ids for f in train_features], dtype=torch.long)
            # all_ngram_masks = torch.tensor([f.ngram_masks for f in train_features], dtype=torch.long)
            ngram_ids = all_ngram_ids.to(device)
            ngram_positions = all_ngram_positions.to(device)
        else:
            ngram_ids = None
            ngram_positions = None
        return input_ids, input_mask, l_mask, label_ids, matching_matrix, ngram_ids, ngram_positions, segment_ids, valid_ids, word_ids, word_mask
class HET(nn.Module):
    """Hierarchical dialogue tagger: a BERT/ZEN utterance encoder, optional
    key-value memory, party/department/disease embeddings, an optional
    (bi)LSTM over the dialogue, and a softmax or CRF decoder.

    Exactly one of ``hpara['use_bert']`` / ``hpara['use_zen']`` must be set.
    """

    def __init__(self, word2id, label2id, hpara, model_path,
                 department2id=None, disease2id=None):
        super().__init__()
        self.word2id = word2id
        self.department2id = None
        self.disease2id = None
        self.label2id = label2id
        self.party2id = None
        self.hpara = hpara
        self.num_labels = len(self.label2id)
        self.max_seq_length = self.hpara['max_seq_length']
        self.use_memory = self.hpara['use_memory']
        self.use_department = self.hpara['use_department']
        self.use_party = self.hpara['use_party']
        self.use_disease = self.hpara['use_disease']
        self.decoder = self.hpara['decoder']
        self.lstm_hidden_size = self.hpara['lstm_hidden_size']
        self.max_dialog_length = self.hpara['max_dialog_length']
        self.bert_tokenizer = None
        self.bert = None
        self.zen_tokenizer = None
        self.zen = None
        self.zen_ngram_dict = None
        # Load the pretrained sentence encoder (BERT or ZEN) from model_path.
        if self.hpara['use_bert']:
            self.bert_tokenizer = BertTokenizer.from_pretrained(
                model_path, do_lower_case=self.hpara['do_lower_case'])
            self.bert = BertModel.from_pretrained(model_path, cache_dir='')
            hidden_size = self.bert.config.hidden_size
            self.dropout = nn.Dropout(self.bert.config.hidden_dropout_prob)
        elif self.hpara['use_zen']:
            self.zen_tokenizer = zen.BertTokenizer.from_pretrained(
                model_path, do_lower_case=self.hpara['do_lower_case'])
            self.zen_ngram_dict = zen.ZenNgramDict(
                model_path, tokenizer=self.zen_tokenizer)
            self.zen = zen.modeling.ZenModel.from_pretrained(model_path,
                                                             cache_dir='')
            hidden_size = self.zen.config.hidden_size
            self.dropout = nn.Dropout(self.zen.config.hidden_dropout_prob)
        else:
            raise ValueError()
        # hidden_size grows as optional components concatenate their outputs;
        # ori_hidden_size keeps the encoder's native width for embeddings.
        ori_hidden_size = hidden_size
        if self.use_memory:
            self.memory = Memory(hidden_size, len(word2id))
            hidden_size = hidden_size * 2
        else:
            self.memory = None
        if self.use_party:
            self.party_embedding = nn.Embedding(5, ori_hidden_size)
            hidden_size += ori_hidden_size
            self.party2id = {'<PAD>': 0, '<UNK>': 1, 'P': 2, 'D': 3}
        else:
            self.party_embedding = None
        utterance_hidden_size = hidden_size
        if self.hpara['utterance_encoder'] == 'LSTM':
            self.utterance_encoder = nn.LSTM(input_size=hidden_size,
                                             hidden_size=self.lstm_hidden_size,
                                             bidirectional=False,
                                             batch_first=True)
            utterance_hidden_size = self.lstm_hidden_size
        elif self.hpara['utterance_encoder'] == 'biLSTM':
            self.utterance_encoder = nn.LSTM(input_size=hidden_size,
                                             hidden_size=self.lstm_hidden_size,
                                             bidirectional=True,
                                             batch_first=True)
            utterance_hidden_size = self.lstm_hidden_size * 2
        else:
            self.utterance_encoder = None
        if self.use_department:
            self.department_embedding = nn.Embedding(len(department2id),
                                                     utterance_hidden_size)
            self.department2id = department2id
        else:
            self.department_embedding = None
        if self.use_disease:
            self.disease_embedding = nn.Embedding(len(disease2id),
                                                  utterance_hidden_size)
            self.disease2id = disease2id
        else:
            self.disease_embedding = None
        # Department/disease embeddings are concatenated in forward(), so the
        # classifier input widens accordingly.
        if self.use_department and self.use_disease:
            utterance_hidden_size = utterance_hidden_size * 3
        elif self.use_department or self.use_disease:
            utterance_hidden_size = utterance_hidden_size * 2
        self.classifier = nn.Linear(utterance_hidden_size, self.num_labels)
        if self.decoder == 'softmax':
            self.loss_fct = CrossEntropyLoss(ignore_index=0)
        elif self.decoder == 'crf':
            self.crf = CRF(self.num_labels, batch_first=True)
        else:
            raise ValueError()

    def forward(self,
                input_ids,
                token_type_ids=None,
                attention_mask=None,
                labels=None,
                valid_ids=None,
                attention_mask_label=None,
                label_mask=None,
                party_mask=None,
                party_ids=None,
                department_ids=None,
                disease_ids=None,
                input_ngram_ids=None,
                ngram_position_matrix=None):
        """Encode a (batch, dialog, utterance) tensor of token ids.

        Returns the training loss when ``labels`` is given, otherwise the
        predicted label ids/scores per utterance.
        """
        batch_size = input_ids.shape[0]
        dialog_length = input_ids.shape[1]
        utterance_length = input_ids.shape[2]
        # Flatten (batch, dialog) so every utterance is encoded independently.
        input_ids = input_ids.view(batch_size * dialog_length,
                                   utterance_length)
        token_type_ids = token_type_ids.view(batch_size * dialog_length,
                                             utterance_length)
        attention_mask = attention_mask.view(batch_size * dialog_length,
                                             utterance_length)
        if self.bert is not None:
            sequence_output, _ = self.bert(input_ids,
                                           token_type_ids,
                                           attention_mask,
                                           output_all_encoded_layers=False)
        elif self.zen is not None:
            ngram_position_matrix = ngram_position_matrix.view(
                batch_size * dialog_length, utterance_length, -1)
            input_ngram_ids = input_ngram_ids.view(batch_size * dialog_length,
                                                   -1)
            sequence_output, _ = self.zen(
                input_ids,
                input_ngram_ids=input_ngram_ids,
                ngram_position_matrix=ngram_position_matrix,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
                output_all_encoded_layers=False)
        else:
            raise ValueError()
        word_embedding_c = None
        if self.use_memory:
            word_embedding_c = self.memory.memory_embeddings(input_ids)
        tmp_sequence_output = sequence_output.view(batch_size, dialog_length,
                                                   utterance_length, -1)
        # Keep only the [CLS] position as the utterance representation.
        sequence_output = tmp_sequence_output[:, :, 0]
        # Zero out padded dialog positions.
        tmp_label_mask = torch.stack([label_mask] * sequence_output.shape[-1],
                                     2)
        sequence_output = torch.mul(sequence_output, tmp_label_mask)
        if self.use_memory:
            word_embedding_c = word_embedding_c.view(batch_size, dialog_length,
                                                     -1)
            memory_output = self.memory(word_embedding_c, sequence_output,
                                        party_mask)
            sequence_output = torch.cat((sequence_output, memory_output), 2)
        sequence_output = self.dropout(sequence_output)
        # Speaker (party) embedding, concatenated per utterance.
        if self.use_party:
            party_embeddings = self.party_embedding(party_ids)
            sequence_output = torch.cat((sequence_output, party_embeddings),
                                        dim=2)
        # Dialogue-level (bi)LSTM over the utterance sequence.
        if self.utterance_encoder is not None:
            self.utterance_encoder.flatten_parameters()
            utterance_output, _ = self.utterance_encoder(sequence_output)
        else:
            utterance_output = sequence_output
        if self.use_department:
            department_embeddings = self.department_embedding(department_ids)
            utterance_output = torch.cat(
                (utterance_output, department_embeddings), dim=2)
        if self.use_disease:
            disease_embeddings = self.disease_embedding(disease_ids)
            utterance_output = torch.cat((utterance_output, disease_embeddings),
                                         dim=2)
        tmp_label_mask = torch.stack(
            [label_mask] * utterance_output.shape[-1], 2)
        utterance_output = torch.mul(utterance_output, tmp_label_mask)
        logits = self.classifier(utterance_output)
        if labels is not None:
            if self.decoder == 'softmax':
                total_loss = self.loss_fct(logits.view(-1, self.num_labels),
                                           labels.view(-1))
            elif self.decoder == 'crf':
                total_loss = -1 * self.crf(emissions=logits,
                                           tags=labels,
                                           mask=attention_mask_label)
            else:
                raise ValueError()
            return total_loss
        else:
            if self.decoder == 'softmax':
                scores = torch.argmax(nn.functional.log_softmax(logits, dim=2),
                                      dim=2)
            elif self.decoder == 'crf':
                # NOTE(review): CRF.decode returns one tag list per batch
                # element; [0] keeps only the first — confirm callers always
                # pass batch size 1 here.
                scores = self.crf.decode(logits, attention_mask_label)[0]
            else:
                raise ValueError()
            return scores

    @staticmethod
    def init_hyper_parameters(args):
        # Build the hyper-parameter dict from parsed command-line args.
        hyper_parameters = DEFAULT_HPARA.copy()
        hyper_parameters['max_seq_length'] = args.max_seq_length
        hyper_parameters['max_ngram_size'] = args.max_ngram_size
        hyper_parameters['use_bert'] = args.use_bert
        hyper_parameters['use_zen'] = args.use_zen
        hyper_parameters['do_lower_case'] = args.do_lower_case
        hyper_parameters['use_memory'] = args.use_memory
        hyper_parameters['use_party'] = args.use_party
        hyper_parameters['use_department'] = args.use_department
        hyper_parameters['use_disease'] = args.use_disease
        hyper_parameters['utterance_encoder'] = args.utterance_encoder
        hyper_parameters['decoder'] = args.decoder
        hyper_parameters['lstm_hidden_size'] = args.lstm_hidden_size
        hyper_parameters['max_dialog_length'] = args.max_dialog_length
        return hyper_parameters

    @classmethod
    def load_model(cls, model_path):
        """Rebuild a HET instance from a directory written by save_model()."""
        label2id = load_json(os.path.join(model_path, 'label2id.json'))
        hpara = load_json(os.path.join(model_path, 'hpara.json'))
        # Optional vocabularies are present only if the saved model used them.
        department2id_path = os.path.join(model_path, 'department2id.json')
        department2id = load_json(department2id_path) if os.path.exists(
            department2id_path) else None
        word2id_path = os.path.join(model_path, 'word2id.json')
        word2id = load_json(word2id_path) if os.path.exists(
            word2id_path) else None
        disease2id_path = os.path.join(model_path, 'disease2id.json')
        disease2id = load_json(disease2id_path) if os.path.exists(
            disease2id_path) else None
        res = cls(model_path=model_path,
                  label2id=label2id,
                  hpara=hpara,
                  department2id=department2id,
                  word2id=word2id,
                  disease2id=disease2id)
        res.load_state_dict(
            torch.load(os.path.join(model_path, 'pytorch_model.bin')))
        return res

    def save_model(self, output_dir, vocab_dir):
        """Persist weights, vocabularies, config and tokenizer files so that
        load_model(output_dir) can restore this instance."""
        import shutil  # stdlib; local import keeps the file's import block untouched
        output_model_path = os.path.join(output_dir, 'pytorch_model.bin')
        torch.save(self.state_dict(), output_model_path)
        label_map_file = os.path.join(output_dir, 'label2id.json')
        if not os.path.exists(label_map_file):
            save_json(label_map_file, self.label2id)
        save_json(os.path.join(output_dir, 'hpara.json'), self.hpara)
        if self.department2id is not None:
            save_json(os.path.join(output_dir, 'department2id.json'),
                      self.department2id)
        if self.word2id is not None:
            save_json(os.path.join(output_dir, 'word2id.json'), self.word2id)
        if self.disease2id is not None:
            save_json(os.path.join(output_dir, 'disease2id.json'),
                      self.disease2id)
        output_config_file = os.path.join(output_dir, 'config.json')
        with open(output_config_file, "w", encoding='utf-8') as writer:
            if self.bert:
                writer.write(self.bert.config.to_json_string())
            elif self.zen:
                writer.write(self.zen.config.to_json_string())
            else:
                raise ValueError()
        # shutil.copyfile replaces the former shell-string `cp` invocations:
        # portable and safe for paths containing whitespace.
        output_bert_config_file = os.path.join(output_dir, 'bert_config.json')
        shutil.copyfile(output_config_file, output_bert_config_file)
        if self.bert or self.zen:
            vocab_name = 'vocab.txt'
        else:
            raise ValueError()
        vocab_path = os.path.join(vocab_dir, vocab_name)
        shutil.copyfile(vocab_path, os.path.join(output_dir, vocab_name))
        if self.zen:
            ngram_name = 'ngram.txt'
            ngram_path = os.path.join(vocab_dir, ngram_name)
            shutil.copyfile(ngram_path, os.path.join(output_dir, ngram_name))

    @staticmethod
    def data2example(data, flag=''):
        # Wrap raw 8-tuples into InputExample objects; `flag` tags the split
        # (e.g. 'train'/'dev') in the example guid.
        examples = []
        for i, (utterance, label, party, summary, max_utterance_len,
                party_mask, department, disease) in enumerate(data):
            guid = "%s-%s" % (flag, i)
            text_a = utterance
            text_b = None
            examples.append(
                InputExample(guid=guid,
                             text_a=text_a,
                             text_b=text_b,
                             label=label,
                             party=party,
                             summary=summary,
                             max_utterance_len=max_utterance_len,
                             party_mask=party_mask,
                             department=department,
                             disease=disease))
        return examples

    def convert_examples_to_features(self, examples):
        """Tokenize and pad a batch of dialogue examples into InputFeatures.

        Per-batch dynamic sizing: utterance length is the batch maximum
        (+10% slack, +2 for [CLS]/[SEP]) capped by max_seq_length; dialog
        length is the batch maximum capped by max_dialog_length.
        """
        features = []
        tokenizer = self.zen_tokenizer if self.zen_tokenizer is not None else self.bert_tokenizer
        # -------- dynamic batch sizes --------
        max_utterance_length = min(
            int(
                max([example.max_utterance_len
                     for example in examples]) * 1.1 + 2),
            self.max_seq_length)
        max_seq_length = max_utterance_length
        max_dialog_length = min(
            max(max([len(example.text_a) for example in examples]), 1),
            self.max_dialog_length)
        # -------- dynamic batch sizes --------
        for (ex_index, example) in enumerate(examples):
            valid = [[] for _ in range(max_dialog_length)]
            tokens = [[] for _ in range(max_dialog_length)]
            segment_ids = [[] for _ in range(max_dialog_length)]
            input_ids = [[] for _ in range(max_dialog_length)]
            input_mask = [[] for _ in range(max_dialog_length)]
            input_id_len = [1 for _ in range(max_dialog_length)]
            party_mask = [[] for _ in range(max_dialog_length)]
            for i in range(max_dialog_length):
                if i < len(example.text_a):
                    utterance = example.text_a[i]
                    party = example.party[i]
                    # Pick the speaker-specific mask (P = patient, D = doctor).
                    if party == 'P':
                        party_mask[i] = example.party_mask['P']
                    elif party == 'D':
                        party_mask[i] = example.party_mask['D']
                    else:
                        raise ValueError()
                    if len(party_mask[i]) > max_dialog_length:
                        party_mask[i] = party_mask[i][:max_dialog_length]
                    while len(party_mask[i]) < max_dialog_length:
                        party_mask[i].append(0)
                    # Character-level tokenization; only a word's first
                    # sub-token is marked valid.
                    for word in utterance:
                        token = tokenizer.tokenize(word)
                        tokens[i].extend(token)
                        for m in range(len(token)):
                            if m == 0:
                                valid[i].append(1)
                            else:
                                valid[i].append(0)
                    if len(tokens[i]) >= max_utterance_length - 1:
                        tokens[i] = tokens[i][0:(max_utterance_length - 2)]
                        valid[i] = valid[i][0:(max_utterance_length - 2)]
                # Every dialog slot (including padding utterances) gets a
                # [CLS] ... [SEP] frame so tensor shapes stay rectangular.
                ntokens = []
                ntokens.append("[CLS]")
                segment_ids[i].append(0)
                valid[i].insert(0, 1)
                for token in tokens[i]:
                    ntokens.append(token)
                    segment_ids[i].append(0)
                ntokens.append("[SEP]")
                segment_ids[i].append(0)
                valid[i].append(1)
                # ntokens: ['[CLS]', c1, ..., cN, '[SEP]'] -> length N + 2;
                # valid/segment_ids are kept aligned with it.
                input_ids[i] = tokenizer.convert_tokens_to_ids(ntokens)
                for _ in range(len(input_ids[i])):
                    input_mask[i].append(1)
                input_id_len[i] = len(input_ids[i])
                # Zero-pad the utterance to max_utterance_length (valid is
                # padded with 1s, matching the sibling model above).
                while len(input_ids[i]) < max_utterance_length:
                    input_ids[i].append(0)
                    input_mask[i].append(0)
                    segment_ids[i].append(0)
                    valid[i].append(1)
                while len(party_mask[i]) < max_dialog_length:
                    party_mask[i].append(0)
                assert len(input_ids[i]) == len(input_mask[i])
                assert len(input_ids[i]) == len(segment_ids[i])
                assert len(input_ids[i]) == len(valid[i])
            assert len(input_ids) == max_dialog_length
            assert len(input_ids[-1]) == max_utterance_length
            # Per-utterance labels, mapped through label2id with <UNK> fallback.
            labellist = example.label
            label_mask = []
            label_ids = []
            for label in labellist:
                label_id = self.label2id[
                    label] if label in self.label2id else self.label2id['<UNK>']
                label_ids.append(label_id)
                label_mask.append(1)
            if len(label_ids) > max_dialog_length:
                label_ids = label_ids[:max_dialog_length]
                label_mask = label_mask[:max_dialog_length]
            while len(label_ids) < max_dialog_length:
                label_ids.append(0)
                label_mask.append(0)
            partylist = example.party
            if self.party2id is not None:
                party_ids = []
                for party in partylist:
                    party_ids.append(self.party2id[party])
                if len(party_ids) > max_dialog_length:
                    party_ids = party_ids[:max_dialog_length]
                while len(party_ids) < max_dialog_length:
                    party_ids.append(0)
            else:
                party_ids = None
            if self.department2id is not None:
                # One (identical) department id per utterance slot.
                department_ids = []
                if example.department in self.department2id:
                    department_id = self.department2id[example.department]
                else:
                    department_id = self.department2id['<UNK>']
                for _ in partylist:
                    department_ids.append(department_id)
                if len(department_ids) > max_dialog_length:
                    department_ids = department_ids[:max_dialog_length]
                while len(department_ids) < max_dialog_length:
                    department_ids.append(0)
            else:
                department_ids = None
            if self.disease2id is not None:
                # One (identical) disease id per utterance slot.
                disease_ids = []
                if example.disease in self.disease2id:
                    disease_id = self.disease2id[example.disease]
                else:
                    disease_id = self.disease2id['<UNK>']
                for _ in partylist:
                    disease_ids.append(disease_id)
                if len(disease_ids) > max_dialog_length:
                    disease_ids = disease_ids[:max_dialog_length]
                while len(disease_ids) < max_dialog_length:
                    disease_ids.append(0)
            else:
                disease_ids = None
            assert len(label_ids) == len(label_mask)
            assert len(label_ids) == max_dialog_length
            assert len(label_ids) == len(party_mask)
            assert len(label_ids) == len(party_mask[-1])
            if self.zen_ngram_dict is not None:
                all_ngram_ids = []
                all_ngram_positions_matrix = []
                for token_list in tokens:
                    ngram_matches = []
                    # Scan n-grams of length 2..7 against the ZEN lexicon;
                    # q is the start position, p the n-gram length.
                    for p in range(2, 8):
                        for q in range(0, len(token_list) - p + 1):
                            character_segment = token_list[q:q + p]
                            character_segment = tuple(character_segment)
                            if character_segment in self.zen_ngram_dict.ngram_to_id_dict:
                                ngram_index = self.zen_ngram_dict.ngram_to_id_dict[
                                    character_segment]
                                ngram_matches.append(
                                    [ngram_index, q, p, character_segment])
                    ngram_matches = sorted(ngram_matches, key=lambda s: s[0])
                    # Cap n-gram count proportionally to sequence occupancy.
                    max_ngram_in_seq_proportion = math.ceil(
                        (len(token_list) / max_seq_length) *
                        self.zen_ngram_dict.max_ngram_in_seq)
                    if len(ngram_matches) > max_ngram_in_seq_proportion:
                        ngram_matches = ngram_matches[:
                                                      max_ngram_in_seq_proportion]
                    ngram_ids = [ngram[0] for ngram in ngram_matches]
                    ngram_positions = [ngram[1] for ngram in ngram_matches]
                    ngram_lengths = [ngram[2] for ngram in ngram_matches]
                    # FIX: np.bool was deprecated in NumPy 1.20 and removed in
                    # 1.24; the builtin bool is the documented replacement.
                    ngram_mask_array = np.zeros(
                        self.zen_ngram_dict.max_ngram_in_seq, dtype=bool)
                    ngram_mask_array[:len(ngram_ids)] = 1
                    ngram_positions_matrix = np.zeros(
                        shape=(max_seq_length,
                               self.zen_ngram_dict.max_ngram_in_seq),
                        dtype=np.int32)
                    for i in range(len(ngram_ids)):
                        ngram_positions_matrix[
                            ngram_positions[i]:ngram_positions[i] +
                            ngram_lengths[i], i] = 1.0
                    # Zero-pad up to the max ngram in seq length.
                    padding = [0] * (self.zen_ngram_dict.max_ngram_in_seq -
                                     len(ngram_ids))
                    ngram_ids += padding
                    all_ngram_ids.append(ngram_ids)
                    all_ngram_positions_matrix.append(ngram_positions_matrix)
                # Pad missing dialog slots with empty n-gram rows.
                while len(all_ngram_ids) < max_dialog_length:
                    all_ngram_ids.append(
                        [0] * self.zen_ngram_dict.max_ngram_in_seq)
                    all_ngram_positions_matrix.append(
                        np.zeros(shape=(max_seq_length,
                                        self.zen_ngram_dict.max_ngram_in_seq),
                                 dtype=np.int32))
            else:
                all_ngram_ids = None
                all_ngram_positions_matrix = None
            features.append(
                InputFeatures(
                    input_ids=input_ids,
                    input_mask=input_mask,
                    segment_ids=segment_ids,
                    label_id=label_ids,
                    valid_ids=valid,
                    label_mask=label_mask,
                    input_id_len=input_id_len,
                    party_mask=party_mask,
                    party=party_ids,
                    department=department_ids,
                    disease=disease_ids,
                    ngram_ids=all_ngram_ids,
                    ngram_positions=all_ngram_positions_matrix,
                ))
        return features

    def feature2input(self, device, feature):
        # Collate a list of InputFeatures into padded batch tensors on
        # `device`; optional components yield None when disabled.
        all_input_ids = torch.tensor([f.input_ids for f in feature],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in feature],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in feature],
                                       dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in feature],
                                     dtype=torch.long)
        all_valid_ids = torch.tensor([f.valid_ids for f in feature],
                                     dtype=torch.long)
        all_lmask_ids = torch.tensor([f.label_mask for f in feature],
                                     dtype=torch.long)
        input_ids = all_input_ids.to(device)
        input_mask = all_input_mask.to(device)
        segment_ids = all_segment_ids.to(device)
        label_ids = all_label_ids.to(device)
        valid_ids = all_valid_ids.to(device)
        l_mask = all_lmask_ids.to(device)
        # Float copy of the label mask, used for masking multiplications.
        all_lmask = torch.tensor([f.label_mask for f in feature],
                                 dtype=torch.float)
        lmask = all_lmask.to(device)
        if self.memory is not None:
            all_party_mask = torch.tensor([f.party_mask for f in feature],
                                          dtype=torch.float)
            party_mask = all_party_mask.to(device)
        else:
            party_mask = None
        if self.use_party:
            all_party_ids = torch.tensor([f.party for f in feature],
                                         dtype=torch.long)
            party_ids = all_party_ids.to(device)
        else:
            party_ids = None
        if self.use_department:
            all_department_ids = torch.tensor([f.department for f in feature],
                                              dtype=torch.long)
            department_ids = all_department_ids.to(device)
        else:
            department_ids = None
        if self.use_disease:
            all_disease_ids = torch.tensor([f.disease for f in feature],
                                           dtype=torch.long)
            disease_ids = all_disease_ids.to(device)
        else:
            disease_ids = None
        if self.zen is not None:
            all_ngram_ids = torch.tensor([f.ngram_ids for f in feature],
                                         dtype=torch.long)
            all_ngram_positions = torch.tensor(
                [f.ngram_positions for f in feature], dtype=torch.long)
            ngram_ids = all_ngram_ids.to(device)
            ngram_positions = all_ngram_positions.to(device)
        else:
            ngram_ids = None
            ngram_positions = None
        return input_ids, input_mask, l_mask, label_ids, ngram_ids, ngram_positions, segment_ids, valid_ids, \
            lmask, party_mask, party_ids, department_ids, disease_ids