class SeqModel(nn.Module):
    def __init__(self, data):
        super(SeqModel, self).__init__()
        self.data = data
        self.use_crf = data.use_crf
        print("build network...")
        print("word feature extractor: ", data.word_feature_extractor)
        self.gpu = data.HP_gpu
        self.average_batch = data.average_batch_loss
        # opinion and evidence are extracted separately
        label_size = data.label_alphabet_size
        self.word_hidden = WordSequence(data)
        if self.use_crf:
            self.word_crf = CRF(label_size, batch_first=True)
            if self.gpu:
                self.word_crf = self.word_crf.cuda()

    def neg_log_likelihood_loss(self, word_inputs, word_seq_lengths, batch_label,
                                mask, input_label_seq_tensor):
        # lstm_outs: (batch_size, sentence_length, tag_size)
        lstm_outs = self.word_hidden(word_inputs, word_seq_lengths,
                                     input_label_seq_tensor)
        batch_size = word_inputs.size(0)
        if self.use_crf:
            mask = mask.byte()
            loss = -self.word_crf(lstm_outs, batch_label, mask)
            tag_seq = self.word_crf.decode(lstm_outs, mask)
        else:
            loss_function = nn.NLLLoss()
            seq_len = lstm_outs.size(1)
            lstm_outs = lstm_outs.view(batch_size * seq_len, -1)
            score = F.log_softmax(lstm_outs, 1)
            loss = loss_function(
                score, batch_label.contiguous().view(batch_size * seq_len))
            _, tag_seq = torch.max(score, 1)
            tag_seq = tag_seq.view(batch_size, seq_len)
        return loss, tag_seq

    def evaluate(self, word_inputs, word_seq_lengths, mask, input_label_seq_tensor):
        lstm_outs = self.word_hidden(word_inputs, word_seq_lengths,
                                     input_label_seq_tensor)
        if self.use_crf:
            mask = mask.byte()
            tag_seq = self.word_crf.decode(lstm_outs, mask)
        else:
            batch_size = word_inputs.size(0)
            seq_len = lstm_outs.size(1)
            lstm_outs = lstm_outs.view(batch_size * seq_len, -1)
            _, tag_seq = torch.max(lstm_outs, 1)
            tag_seq = mask.long() * tag_seq.view(batch_size, seq_len)
        return tag_seq

    def forward(self, word_inputs, word_seq_lengths, mask, input_label_seq_tensor):
        return self.evaluate(word_inputs, word_seq_lengths, mask,
                             input_label_seq_tensor)
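# Minimal standalone sketch of the CRF calls used in SeqModel above, assuming the
# CRF class is pytorch-crf's torchcrf.CRF (which the batch_first=True argument
# suggests). Tensor shapes and values below are purely illustrative.
import torch
from torchcrf import CRF

num_tags, batch_size, seq_len = 5, 2, 4
crf = CRF(num_tags, batch_first=True)

emissions = torch.randn(batch_size, seq_len, num_tags)          # per-token tag scores from the encoder
tags = torch.randint(0, num_tags, (batch_size, seq_len))        # gold label ids
mask = torch.ones(batch_size, seq_len, dtype=torch.uint8)       # 1 = real token, 0 = padding
mask[1, 3] = 0                                                  # last token of the second sample is padding

loss = -crf(emissions, tags, mask)          # negative log-likelihood, summed over the batch by default
best_paths = crf.decode(emissions, mask)    # list of tag-id lists, one per sentence
print(loss.item(), best_paths)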
class GazLSTM(nn.Module):
    def __init__(self, data):
        super(GazLSTM, self).__init__()
        self.gpu = data.HP_gpu
        self.use_biword = data.use_bigram
        self.hidden_dim = data.HP_hidden_dim
        self.gaz_alphabet = data.gaz_alphabet
        self.gaz_emb_dim = data.gaz_emb_dim
        self.word_emb_dim = data.word_emb_dim
        self.biword_emb_dim = data.biword_emb_dim
        self.use_char = data.HP_use_char
        self.bilstm_flag = data.HP_bilstm
        self.lstm_layer = data.HP_lstm_layer
        self.use_count = data.HP_use_count
        self.num_layer = data.HP_num_layer
        self.model_type = data.model_type
        self.use_bert = data.use_bert
        # whether to use the external dictionary features
        self.use_dictionary = data.use_dictionary
        self.simi_dic_emb = data.simi_dic_emb
        self.simi_dic_dim = data.simi_dic_dim

        scale = np.sqrt(3.0 / self.gaz_emb_dim)
        data.pretrain_gaz_embedding[0, :] = np.random.uniform(
            -scale, scale, [1, self.gaz_emb_dim])
        if self.use_char:
            scale = np.sqrt(3.0 / self.word_emb_dim)
            data.pretrain_word_embedding[0, :] = np.random.uniform(
                -scale, scale, [1, self.word_emb_dim])

        # randomly initialised gaz / word embedding matrices
        self.gaz_embedding = nn.Embedding(data.gaz_alphabet.size(),
                                          self.gaz_emb_dim)
        self.word_embedding = nn.Embedding(data.word_alphabet.size(),
                                           self.word_emb_dim)
        if self.use_biword:
            self.biword_embedding = nn.Embedding(data.biword_alphabet.size(),
                                                 self.biword_emb_dim)

        if data.pretrain_gaz_embedding is not None:
            # copy the pretrained gaz embeddings into gaz_embedding
            self.gaz_embedding.weight.data.copy_(
                torch.from_numpy(data.pretrain_gaz_embedding))
        else:
            self.gaz_embedding.weight.data.copy_(
                torch.from_numpy(
                    self.random_embedding(data.gaz_alphabet.size(),
                                          self.gaz_emb_dim)))
        if data.pretrain_word_embedding is not None:
            self.word_embedding.weight.data.copy_(
                torch.from_numpy(data.pretrain_word_embedding))
        else:
            self.word_embedding.weight.data.copy_(
                torch.from_numpy(
                    self.random_embedding(data.word_alphabet.size(),
                                          self.word_emb_dim)))
        if self.use_biword:
            if data.pretrain_biword_embedding is not None:
                self.biword_embedding.weight.data.copy_(
                    torch.from_numpy(data.pretrain_biword_embedding))
            else:
                self.biword_embedding.weight.data.copy_(
                    torch.from_numpy(
                        self.random_embedding(data.biword_alphabet.size(),
                                              self.word_emb_dim)))

        use_gazcount = True
        # feature dimension per character
        char_feature_dim = self.word_emb_dim + 4 * self.gaz_emb_dim
        if self.use_dictionary:
            if use_gazcount:
                char_feature_dim += self.simi_dic_dim
            else:
                char_feature_dim = self.word_emb_dim  # + self.simi_dic_dim
        if self.use_biword:
            char_feature_dim += self.biword_emb_dim
        if self.use_bert:
            char_feature_dim = char_feature_dim + 768

        ## lstm model
        if self.model_type == 'lstm':
            lstm_hidden = self.hidden_dim
            if self.bilstm_flag:
                self.hidden_dim *= 2
            self.NERmodel = NERmodel(model_type='lstm',
                                     input_dim=char_feature_dim,
                                     hidden_dim=lstm_hidden,
                                     num_layer=self.lstm_layer,
                                     biflag=self.bilstm_flag)

        ## cnn model
        if self.model_type == 'cnn':
            self.NERmodel = NERmodel(model_type='cnn',
                                     input_dim=char_feature_dim,
                                     hidden_dim=self.hidden_dim,
                                     num_layer=self.num_layer,
                                     dropout=data.HP_dropout,
                                     gpu=self.gpu)

        ## attention model
        if self.model_type == 'transformer':
            self.NERmodel = NERmodel(model_type='transformer',
                                     input_dim=char_feature_dim,
                                     hidden_dim=self.hidden_dim,
                                     num_layer=self.num_layer,
                                     dropout=data.HP_dropout)

        # randomly zero activations with probability HP_dropout
        self.drop = nn.Dropout(p=data.HP_dropout)
        self.hidden2tag = nn.Linear(self.hidden_dim,
                                    data.label_alphabet_size + 2)
        self.crf = CRF(data.label_alphabet_size, self.gpu)

        if self.use_bert:
            self.bert_encoder = BertModel.from_pretrained('bert-base-chinese')
            for p in self.bert_encoder.parameters():
                p.requires_grad = False

        if self.gpu:
            self.gaz_embedding = self.gaz_embedding.cuda()
            self.word_embedding = self.word_embedding.cuda()
            if self.use_biword:
                self.biword_embedding = self.biword_embedding.cuda()
            self.NERmodel = self.NERmodel.cuda()
            self.hidden2tag = self.hidden2tag.cuda()
            self.crf = self.crf.cuda()
            if self.use_bert:
                self.bert_encoder = self.bert_encoder.cuda()

    # Build the fused character / lexicon (gaz) / dictionary features and map them
    # to per-token tag scores.
    def get_tags(self, gaz_list, word_inputs, biword_inputs, layer_gaz, gaz_count,
                 gaz_chars, gaz_mask_input, gazchar_mask_input, mask,
                 word_seq_lengths, batch_bert, bert_mask, simi_value):
        use_gazcount = True
        batch_size = word_inputs.size()[0]
        seq_len = word_inputs.size()[1]
        max_gaz_num = layer_gaz.size(-1)
        gaz_match = []

        word_embs = self.word_embedding(word_inputs)
        if self.use_biword:
            biword_embs = self.biword_embedding(biword_inputs)
            word_embs = torch.cat([word_embs, biword_embs], dim=-1)

        if self.model_type != 'transformer':
            word_inputs_d = self.drop(word_embs)  # (b,l,we)
        else:
            word_inputs_d = word_embs

        if self.use_char:
            gazchar_embeds = self.word_embedding(gaz_chars)
            gazchar_mask = gazchar_mask_input.unsqueeze(-1).repeat(
                1, 1, 1, 1, 1, self.word_emb_dim)
            gazchar_embeds = gazchar_embeds.data.masked_fill_(
                gazchar_mask.data, 0)  # (b,l,4,gl,cl,ce)

            # gazchar_mask_input: (b,l,4,gl,cl)
            gaz_charnum = (gazchar_mask_input == 0).sum(
                dim=-1, keepdim=True).float()  # (b,l,4,gl,1)
            gaz_charnum = gaz_charnum + (gaz_charnum == 0).float()
            gaz_embeds = gazchar_embeds.sum(-2) / gaz_charnum  # (b,l,4,gl,ce)

            if self.model_type != 'transformer':
                gaz_embeds = self.drop(gaz_embeds)
            else:
                gaz_embeds = gaz_embeds
        else:  # use gaz embedding
            gaz_embeds = self.gaz_embedding(layer_gaz)
            if self.model_type != 'transformer':
                gaz_embeds_d = self.drop(gaz_embeds)
            else:
                gaz_embeds_d = gaz_embeds
            gaz_mask = gaz_mask_input.unsqueeze(-1).repeat(
                1, 1, 1, 1, self.gaz_emb_dim)
            gaz_embeds = gaz_embeds_d.data.masked_fill_(
                gaz_mask.data, 0)  # (b,l,4,g,ge)  ge: gaz_embed_dim

        if self.use_count:
            count_sum = torch.sum(gaz_count, dim=3, keepdim=True)  # (b,l,4,gn)
            count_sum = torch.sum(count_sum, dim=2, keepdim=True)  # (b,l,1,1)
            weights = gaz_count.div(count_sum)  # (b,l,4,g)
            weights = weights * 4
            weights = weights.unsqueeze(-1)
            gaz_embeds = weights * gaz_embeds  # (b,l,4,g,e)
            gaz_embeds = torch.sum(gaz_embeds, dim=3)  # (b,l,4,e)
        else:
            gaz_num = (gaz_mask_input == 0).sum(
                dim=-1, keepdim=True).float()  # (b,l,4,1)
            gaz_embeds = gaz_embeds.sum(-2) / gaz_num  # (b,l,4,ge)/(b,l,4,1)

        gaz_embeds_cat = gaz_embeds.view(batch_size, seq_len, -1)  # (b,l,4*ge)

        if self.use_dictionary:
            # build and concatenate the dictionary-similarity vectors
            simi_embeds = []
            for key in simi_value:
                for value in key:
                    simi = [value for i in range(self.simi_dic_dim)]
                    simi_embeds.append(simi)
            simi_embeds = torch.Tensor(simi_embeds)
            simi_embeds = simi_embeds.cuda()
            simi_embeds_cat = simi_embeds.view(batch_size, seq_len, -1)
            self.simi_dic_emb = simi_embeds
            if use_gazcount:
                word_input_cat = torch.cat(
                    [word_inputs_d, gaz_embeds_cat, simi_embeds_cat], dim=-1)
            else:
                # word_input_cat = torch.cat([word_inputs_d, simi_embeds_cat], dim=-1)
                word_input_cat = torch.cat([word_inputs_d], dim=-1)
        else:
            word_input_cat = torch.cat([word_inputs_d, gaz_embeds_cat],
                                       dim=-1)  # (b,l,we+4*ge)

        ### cat bert feature
        if self.use_bert:
            seg_id = torch.zeros(bert_mask.size()).long().cuda()
            outputs = self.bert_encoder(batch_bert, bert_mask, seg_id)
            outputs = outputs[0][:, 1:-1, :]  # drop the [CLS]/[SEP] positions
            word_input_cat = torch.cat([word_input_cat, outputs], dim=-1)

        feature_out_d = self.NERmodel(word_input_cat)
        tags = self.hidden2tag(feature_out_d)
        return tags, gaz_match

    def neg_log_likelihood_loss(self, gaz_list, word_inputs, biword_inputs,
                                word_seq_lengths, layer_gaz, gaz_count, gaz_chars,
                                gaz_mask, gazchar_mask, mask, batch_label,
                                batch_bert, bert_mask, simi_value):
        tags, _ = self.get_tags(gaz_list, word_inputs, biword_inputs, layer_gaz,
                                gaz_count, gaz_chars, gaz_mask, gazchar_mask, mask,
                                word_seq_lengths, batch_bert, bert_mask, simi_value)
        total_loss = self.crf.neg_log_likelihood_loss(tags, mask, batch_label)
        scores, tag_seq = self.crf._viterbi_decode(tags, mask)
        return total_loss, tag_seq

    def forward(self, gaz_list, word_inputs, biword_inputs, word_seq_lengths,
                layer_gaz, gaz_count, gaz_chars, gaz_mask, gazchar_mask, mask,
                batch_bert, bert_mask, simi_value):
        tags, gaz_match = self.get_tags(gaz_list, word_inputs, biword_inputs,
                                        layer_gaz, gaz_count, gaz_chars, gaz_mask,
                                        gazchar_mask, mask, word_seq_lengths,
                                        batch_bert, bert_mask, simi_value)
        scores, tag_seq = self.crf._viterbi_decode(tags, mask)
        return tag_seq, gaz_match
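# Standalone sketch of the frequency-weighted gaz pooling used in get_tags above:
# each of the 4 gaz slots (B/M/E/S matches) per character is pooled into a single
# vector by weighting every matched lexicon word's embedding with its count,
# normalised over all matches at that character and rescaled by 4. Shapes and
# values are illustrative.
import torch

b, l, g, ge = 2, 5, 3, 8                                  # batch, seq_len, matches per slot, gaz_emb_dim
gaz_count = torch.randint(1, 6, (b, l, 4, g)).float()     # counts of each matched lexicon word (>=1 here for simplicity)
gaz_embeds = torch.randn(b, l, 4, g, ge)                  # embeddings of the matched words

count_sum = gaz_count.sum(dim=3, keepdim=True).sum(dim=2, keepdim=True)   # (b,l,1,1)
weights = gaz_count.div(count_sum) * 4                                    # (b,l,4,g), sums to 4 per character
gaz_pooled = (weights.unsqueeze(-1) * gaz_embeds).sum(dim=3)              # (b,l,4,ge)
gaz_feature = gaz_pooled.view(b, l, -1)                                   # (b,l,4*ge), concatenated to the char features
print(gaz_feature.shape)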
class GazLSTM(nn.Module):
    def __init__(self, data):
        super(GazLSTM, self).__init__()
        self.gpu = data.HP_gpu
        self.use_biword = data.use_bigram
        self.hidden_dim = data.HP_hidden_dim
        self.gaz_alphabet = data.gaz_alphabet
        self.gaz_emb_dim = data.gaz_emb_dim
        self.word_emb_dim = data.word_emb_dim
        self.biword_emb_dim = data.biword_emb_dim
        self.use_char = data.HP_use_char
        self.bilstm_flag = data.HP_bilstm
        self.lstm_layer = data.HP_lstm_layer
        self.use_count = data.HP_use_count
        self.num_layer = data.HP_num_layer
        self.model_type = data.model_type

        scale = np.sqrt(3.0 / self.gaz_emb_dim)
        data.pretrain_gaz_embedding[0, :] = np.random.uniform(
            -scale, scale, [1, self.gaz_emb_dim])
        if self.use_char:
            scale = np.sqrt(3.0 / self.word_emb_dim)
            data.pretrain_word_embedding[0, :] = np.random.uniform(
                -scale, scale, [1, self.word_emb_dim])

        self.gaz_embedding = nn.Embedding(data.gaz_alphabet.size(),
                                          self.gaz_emb_dim)
        self.word_embedding = nn.Embedding(data.word_alphabet.size(),
                                           self.word_emb_dim)
        if self.use_biword:
            self.biword_embedding = nn.Embedding(data.biword_alphabet.size(),
                                                 self.biword_emb_dim)

        if data.pretrain_gaz_embedding is not None:
            self.gaz_embedding.weight.data.copy_(
                torch.from_numpy(data.pretrain_gaz_embedding))
        else:
            self.gaz_embedding.weight.data.copy_(
                torch.from_numpy(
                    self.random_embedding(data.gaz_alphabet.size(),
                                          self.gaz_emb_dim)))
        if data.pretrain_word_embedding is not None:
            self.word_embedding.weight.data.copy_(
                torch.from_numpy(data.pretrain_word_embedding))
        else:
            self.word_embedding.weight.data.copy_(
                torch.from_numpy(
                    self.random_embedding(data.word_alphabet.size(),
                                          self.word_emb_dim)))
        if self.use_biword:
            if data.pretrain_biword_embedding is not None:
                self.biword_embedding.weight.data.copy_(
                    torch.from_numpy(data.pretrain_biword_embedding))
            else:
                self.biword_embedding.weight.data.copy_(
                    torch.from_numpy(
                        self.random_embedding(data.biword_alphabet.size(),
                                              self.word_emb_dim)))

        char_feature_dim = self.word_emb_dim + 4 * self.gaz_emb_dim
        if self.use_biword:
            char_feature_dim += self.biword_emb_dim

        ## lstm model
        if self.model_type == 'lstm':
            lstm_hidden = self.hidden_dim
            if self.bilstm_flag:
                self.hidden_dim *= 2
            self.NERmodel = NERmodel(model_type='lstm',
                                     input_dim=char_feature_dim,
                                     hidden_dim=lstm_hidden,
                                     num_layer=self.lstm_layer,
                                     biflag=self.bilstm_flag)

        ## cnn model
        if self.model_type == 'cnn':
            self.NERmodel = NERmodel(model_type='cnn',
                                     input_dim=char_feature_dim,
                                     hidden_dim=self.hidden_dim,
                                     num_layer=self.num_layer,
                                     dropout=data.HP_dropout,
                                     gpu=self.gpu)

        ## attention model
        if self.model_type == 'transformer':
            self.NERmodel = NERmodel(model_type='transformer',
                                     input_dim=char_feature_dim,
                                     hidden_dim=self.hidden_dim,
                                     num_layer=self.num_layer,
                                     dropout=data.HP_dropout)

        self.drop = nn.Dropout(p=data.HP_dropout)
        self.hidden2tag = nn.Linear(self.hidden_dim,
                                    data.label_alphabet_size + 2)
        self.crf = CRF(data.label_alphabet_size, self.gpu)

        if self.gpu:
            # self.drop = self.drop.cuda()
            self.gaz_embedding = self.gaz_embedding.cuda()
            self.word_embedding = self.word_embedding.cuda()
            if self.use_biword:
                self.biword_embedding = self.biword_embedding.cuda()
            self.NERmodel = self.NERmodel.cuda()
            self.hidden2tag = self.hidden2tag.cuda()
            self.crf = self.crf.cuda()

    def get_tags(self, gaz_list, word_inputs, biword_inputs, layer_gaz, gaz_count,
                 gaz_chars, gaz_mask_input, gazchar_mask_input, mask):
        batch_size = word_inputs.size()[0]
        seq_len = word_inputs.size()[1]
        max_gaz_num = layer_gaz.size(-1)
        gaz_match = []

        word_embs = self.word_embedding(word_inputs)
        if self.use_biword:
            biword_embs = self.biword_embedding(biword_inputs)
            word_embs = torch.cat([word_embs, biword_embs], dim=-1)

        if self.model_type != 'transformer':
            word_inputs_d = self.drop(word_embs)  # (b,l,we)
        else:
            word_inputs_d = word_embs

        if self.use_char:
            gazchar_embeds = self.word_embedding(gaz_chars)
            gazchar_mask = gazchar_mask_input.unsqueeze(-1).repeat(
                1, 1, 1, 1, 1, self.word_emb_dim)
            gazchar_embeds = gazchar_embeds.data.masked_fill_(
                gazchar_mask.data, 0)  # (b,l,4,gl,cl,ce)

            # gazchar_mask_input: (b,l,4,gl,cl)
            gaz_charnum = (gazchar_mask_input == 0).sum(
                dim=-1, keepdim=True).float()  # (b,l,4,gl,1)
            gaz_charnum = gaz_charnum + (gaz_charnum == 0).float()
            gaz_embeds = gazchar_embeds.sum(-2) / gaz_charnum  # (b,l,4,gl,ce)

            if self.model_type != 'transformer':
                gaz_embeds = self.drop(gaz_embeds)
            else:
                gaz_embeds = gaz_embeds
        else:  # use gaz embedding
            gaz_embeds = self.gaz_embedding(layer_gaz)
            if self.model_type != 'transformer':
                gaz_embeds_d = self.drop(gaz_embeds)
            else:
                gaz_embeds_d = gaz_embeds
            gaz_mask = gaz_mask_input.unsqueeze(-1).repeat(
                1, 1, 1, 1, self.gaz_emb_dim)
            gaz_embeds = gaz_embeds_d.data.masked_fill_(
                gaz_mask.data, 0)  # (b,l,4,g,ge)  ge: gaz_embed_dim

        if self.use_count:
            count_sum = torch.sum(gaz_count, dim=3, keepdim=True)  # (b,l,4,gn)
            count_sum = torch.sum(count_sum, dim=2, keepdim=True)  # (b,l,1,1)
            weights = gaz_count.div(count_sum)  # (b,l,4,g)
            weights = weights * 4
            weights = weights.unsqueeze(-1)
            gaz_embeds = weights * gaz_embeds  # (b,l,4,g,e)
            gaz_embeds = torch.sum(gaz_embeds, dim=3)  # (b,l,4,e)
        else:
            gaz_num = (gaz_mask_input == 0).sum(
                dim=-1, keepdim=True).float()  # (b,l,4,1)
            gaz_embeds = gaz_embeds.sum(-2) / gaz_num  # (b,l,4,ge)/(b,l,4,1)

        gaz_embeds_cat = gaz_embeds.view(batch_size, seq_len, -1)  # (b,l,4*ge)
        word_input_cat = torch.cat([word_inputs_d, gaz_embeds_cat],
                                   dim=-1)  # (b,l,we+4*ge)

        feature_out_d = self.NERmodel(word_input_cat)
        tags = self.hidden2tag(feature_out_d)
        return tags, gaz_match

    def neg_log_likelihood_loss(self, gaz_list, word_inputs, biword_inputs,
                                word_seq_lengths, layer_gaz, gaz_count, gaz_chars,
                                gaz_mask, gazchar_mask, mask, batch_label):
        tags, _ = self.get_tags(gaz_list, word_inputs, biword_inputs, layer_gaz,
                                gaz_count, gaz_chars, gaz_mask, gazchar_mask, mask)
        total_loss = self.crf.neg_log_likelihood_loss(tags, mask, batch_label)
        scores, tag_seq = self.crf._viterbi_decode(tags, mask)
        return total_loss, tag_seq

    def forward(self, gaz_list, word_inputs, biword_inputs, word_seq_lengths,
                layer_gaz, gaz_count, gaz_chars, gaz_mask, gazchar_mask, mask):
        tags, gaz_match = self.get_tags(gaz_list, word_inputs, biword_inputs,
                                        layer_gaz, gaz_count, gaz_chars, gaz_mask,
                                        gazchar_mask, mask)
        scores, tag_seq = self.crf._viterbi_decode(tags, mask)
        return tag_seq, gaz_match
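# Standalone sketch of the masking pattern used in get_tags above (the use_count=False
# branch): gaz_mask_input marks padded lexicon slots, it is broadcast over the
# embedding dimension, and the padded embeddings are zeroed before averaging over
# the real matches. Shapes are illustrative.
import torch

b, l, g, ge = 2, 5, 3, 8
gaz_embeds = torch.randn(b, l, 4, g, ge)
gaz_mask_input = torch.zeros(b, l, 4, g, dtype=torch.bool)
gaz_mask_input[..., -1] = True                                    # pretend the last match slot is padding

gaz_mask = gaz_mask_input.unsqueeze(-1).repeat(1, 1, 1, 1, ge)    # (b,l,4,g,ge)
gaz_embeds = gaz_embeds.masked_fill(gaz_mask, 0)

gaz_num = (gaz_mask_input == 0).sum(dim=-1, keepdim=True).float() # real matches per slot, (b,l,4,1)
gaz_avg = gaz_embeds.sum(-2) / gaz_num                            # mean over the un-padded matches, (b,l,4,ge)
print(gaz_avg.shape)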
class JointModel(nn.Module):
    def __init__(self, data):
        super(JointModel, self).__init__()
        self.data = data
        self.use_crf = data.use_crf
        logger.info("build network...")
        logger.info("word feature extractor: %s" % data.word_feature_extractor)
        logger.info("use_cuda: %s" % data.HP_gpu)
        self.gpu = data.HP_gpu
        logger.info("use_crf: %s" % data.use_crf)
        self.average_batch = data.average_batch_loss
        label_size = data.label_alphabet_size
        sentence_size = data.sentence_alphabet_size
        self.word_hidden = JointSequence(data)
        if self.use_crf:
            self.word_crf = CRF(label_size, batch_first=True)
            self.sent_crf = CRF(sentence_size, batch_first=True)
            if self.gpu:
                self.word_crf = self.word_crf.cuda()
                self.sent_crf = self.sent_crf.cuda()

    def neg_log_likelihood_loss(self, word_inputs, word_tensor, word_seq_lengths,
                                batch_label, batch_sent_type, mask, sent_mask,
                                input_label_seq_tensor, input_sent_type_tensor,
                                batch_word_recover, word_perm_idx,
                                need_cat=True, need_embedding=True):
        words_outs, sent_out = self.word_hidden(
            word_inputs, word_tensor, word_seq_lengths, input_label_seq_tensor,
            input_sent_type_tensor, batch_word_recover, word_perm_idx,
            batch_sent_type, need_cat, need_embedding)
        batch_size = words_outs.size(0)
        seq_len = words_outs.size(1)
        if self.use_crf:
            # words_outs: (batch_size, sentence_length, tag_size)
            words_loss = (-self.word_crf(words_outs, batch_label, mask)) / (
                len(word_seq_lengths) * seq_len)
            words_tag_seq = self.word_crf.decode(words_outs, mask)
            sent_total_loss = -self.sent_crf(
                sent_out,
                batch_sent_type[batch_word_recover].view(batch_size, 1),
                sent_mask.view(batch_size, 1).byte()) / len(sent_mask)
            sent_tag_seq = self.sent_crf.decode(
                sent_out, sent_mask.view(batch_size, 1).byte())
        else:
            loss_function = nn.NLLLoss()
            words_outs = words_outs.view(batch_size * seq_len, -1)
            words_score = F.log_softmax(words_outs, 1)
            words_loss = loss_function(
                words_score, batch_label.contiguous().view(batch_size * seq_len))
            _, words_tag_seq = torch.max(words_score, 1)
            words_tag_seq = words_tag_seq.view(batch_size, seq_len)
            sent_out = sent_out.view(batch_size, -1)
            sent_score = F.log_softmax(sent_out, 1)
            sent_total_loss = loss_function(
                sent_score, batch_sent_type[batch_word_recover].view(batch_size))
            _, sent_tag_seq = torch.max(sent_score, 1)
        return words_loss, words_tag_seq, sent_total_loss, sent_tag_seq

    def evaluate(self, word_inputs, word_tensor, word_seq_lengths, batch_sent_type,
                 mask, sent_mask, input_label_seq_tensor, input_sent_type_tensor,
                 batch_word_recover, word_perm_idx, need_cat=True,
                 need_embedding=True):
        words_out, sent_out = self.word_hidden(
            word_inputs, word_tensor, word_seq_lengths, input_label_seq_tensor,
            input_sent_type_tensor, batch_word_recover, word_perm_idx,
            batch_sent_type, need_cat, need_embedding)
        batch_size = words_out.size(0)
        seq_len = words_out.size(1)
        if self.use_crf:
            sent_tag_seq = self.sent_crf.decode(
                sent_out, sent_mask.view(batch_size, 1).byte())
            # The sentence-level predictions have already been restored to the
            # original order, while the word-level batch is still sorted by length,
            # so re-permute the sentence tags to line up with the word order.
            sent_tag_seq = torch.tensor(sent_tag_seq)[word_perm_idx]
            if self.gpu:
                sent_tag_seq = sent_tag_seq.cpu().data.numpy().tolist()
            else:
                sent_tag_seq = sent_tag_seq.data.numpy().tolist()
            words_tag_seq = self.word_crf.decode(words_out, mask)
        else:
            sent_out = sent_out.view(batch_size, -1)
            _, sent_tag_seq = torch.max(sent_out, 1)
            # Same re-permutation as above for the non-CRF branch.
            sent_tag_seq = sent_tag_seq[word_perm_idx]
            words_out = words_out.view(batch_size * seq_len, -1)
            _, words_tag_seq = torch.max(words_out, 1)
            words_tag_seq = mask.long() * words_tag_seq.view(batch_size, seq_len)
        return words_tag_seq, sent_tag_seq

    def forward(self, word_inputs, word_tensor, word_seq_lengths, mask, sent_mask,
                input_label_seq_tensor, input_sent_type_tensor, batch_word_recover,
                word_perm_idx, need_cat=True, need_embedding=True):
        batch_size = word_tensor.size(0)
        seq_len = word_tensor.size(1)
        lstm_out, hidden, sent_out, label_embs = self.word_hidden.evaluate_sentence(
            word_inputs, word_tensor, word_seq_lengths, input_label_seq_tensor,
            input_sent_type_tensor, batch_word_recover, need_cat, need_embedding)
        lstm_out = torch.cat([
            lstm_out,
            sent_out[word_perm_idx].expand(
                [lstm_out.size(0), lstm_out.size(1), sent_out.size(-1)])
        ], -1)
        words_outs = self.word_hidden.evaluate_word(lstm_out, hidden,
                                                    word_seq_lengths, label_embs)
        if self.use_crf:
            sent_tag_seq = self.sent_crf.decode(
                sent_out, sent_mask.view(batch_size, 1).byte())
            # Re-permute the sentence tags to the (length-sorted) word order, as above.
            sent_tag_seq = torch.tensor(sent_tag_seq)[word_perm_idx]
            words_tag_seq = self.word_crf.decode(words_outs, mask)
        else:
            sent_out = sent_out.view(batch_size, -1)
            _, sent_tag_seq = torch.max(sent_out, 1)
            # Re-permute the sentence tags to the (length-sorted) word order, as above.
            sent_tag_seq = sent_tag_seq[word_perm_idx]
            words_outs = words_outs.view(batch_size * seq_len, -1)
            _, words_tag_seq = torch.max(words_outs, 1)
            words_tag_seq = mask.long() * words_tag_seq.view(batch_size, seq_len)
        return words_tag_seq, sent_tag_seq
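# Standalone sketch of the non-CRF branch shared by these models: per-token
# log-softmax scores are trained with NLLLoss and decoded with an argmax whose
# padded positions are zeroed via the mask. Shapes and values are illustrative.
import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size, seq_len, tag_size = 2, 4, 5
logits = torch.randn(batch_size, seq_len, tag_size)            # encoder output
gold = torch.randint(0, tag_size, (batch_size, seq_len))       # gold tag ids
mask = torch.tensor([[1, 1, 1, 1],
                     [1, 1, 0, 0]])                            # 0 = padding

score = F.log_softmax(logits.view(batch_size * seq_len, -1), dim=1)
loss = nn.NLLLoss()(score, gold.view(batch_size * seq_len))    # padded positions also contribute, as in the original

_, tag_seq = torch.max(score, 1)
tag_seq = mask.long() * tag_seq.view(batch_size, seq_len)      # force padded positions to tag 0
print(loss.item(), tag_seq)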
class CNNmodel(nn.Module):
    def __init__(self, data):
        super(CNNmodel, self).__init__()
        self.gpu = data.HP_gpu
        self.use_biword = data.use_bigram
        self.use_posi = data.HP_use_posi
        self.hidden_dim = data.HP_hidden_dim
        self.gaz_alphabet = data.gaz_alphabet
        self.gaz_emb_dim = data.gaz_emb_dim
        self.word_emb_dim = data.word_emb_dim
        self.posi_emb_dim = data.posi_emb_dim
        self.biword_emb_dim = data.biword_emb_dim
        self.rethink_iter = data.HP_rethink_iter

        scale = np.sqrt(3.0 / self.gaz_emb_dim)
        data.pretrain_gaz_embedding[0, :] = np.random.uniform(
            -scale, scale, [1, self.gaz_emb_dim])

        self.gaz_embedding = nn.Embedding(data.gaz_alphabet.size(),
                                          self.gaz_emb_dim)
        self.gaz_embedding.weight.data.copy_(
            torch.from_numpy(data.pretrain_gaz_embedding))
        self.word_embedding = nn.Embedding(data.word_alphabet.size(),
                                           self.word_emb_dim)
        self.word_embedding.weight.data.copy_(
            torch.from_numpy(data.pretrain_word_embedding))

        if data.HP_use_posi:
            data.posi_alphabet_size += 1
            self.position_embedding = nn.Embedding.from_pretrained(
                get_sinusoid_encoding_table(data.posi_alphabet_size,
                                            self.posi_emb_dim),
                freeze=True)

        if self.use_biword:
            self.biword_embedding = nn.Embedding(data.biword_alphabet.size(),
                                                 self.biword_emb_dim)
            self.biword_embedding.weight.data.copy_(
                torch.from_numpy(data.pretrain_biword_embedding))

        self.drop = nn.Dropout(p=data.HP_dropout)
        self.num_layer = data.HP_num_layer

        input_dim = self.word_emb_dim
        if self.use_biword:
            input_dim += self.biword_emb_dim
        if self.use_posi:
            input_dim += self.posi_emb_dim

        self.cnn_layer0 = nn.Conv1d(input_dim, self.hidden_dim,
                                    kernel_size=1, padding=0)
        self.cnn_layers = [
            nn.Conv1d(self.hidden_dim, self.hidden_dim, kernel_size=2, padding=0)
            for i in range(self.num_layer - 1)
        ]
        self.cnn_layers_back = [
            nn.Conv1d(self.hidden_dim, self.hidden_dim, kernel_size=2, padding=0)
            for i in range(self.num_layer - 1)
        ]
        self.res_cnn_layers = [
            nn.Conv1d(self.hidden_dim, self.hidden_dim, kernel_size=i + 2, padding=0)
            for i in range(1, self.num_layer - 1)
        ]
        self.res_cnn_layers_back = [
            nn.Conv1d(self.hidden_dim, self.hidden_dim, kernel_size=i + 2, padding=0)
            for i in range(1, self.num_layer - 1)
        ]

        self.layer_gate = LayerGate(self.hidden_dim, self.gaz_emb_dim)
        self.global_gate = GlobalGate(self.hidden_dim)
        self.exper2gate = nn.Linear(self.hidden_dim, self.hidden_dim * 4)
        self.multiscale_layer = MultiscaleAttention(self.num_layer,
                                                    data.HP_dropout)
        self.hidden2tag = nn.Linear(self.hidden_dim,
                                    data.label_alphabet_size + 2)
        self.crf = CRF(data.label_alphabet_size, self.gpu)

        if self.gpu:
            self.gaz_embedding = self.gaz_embedding.cuda()
            self.word_embedding = self.word_embedding.cuda()
            if self.use_posi:
                self.position_embedding = self.position_embedding.cuda()
            if self.use_biword:
                self.biword_embedding = self.biword_embedding.cuda()
            self.cnn_layer0 = self.cnn_layer0.cuda()
            self.multiscale_layer = self.multiscale_layer.cuda()
            self.hidden2tag = self.hidden2tag.cuda()
            self.layer_gate = self.layer_gate.cuda()
            self.global_gate = self.global_gate.cuda()
            self.crf = self.crf.cuda()
            for i in range(self.num_layer - 1):
                self.cnn_layers[i] = self.cnn_layers[i].cuda()
                self.cnn_layers_back[i] = self.cnn_layers_back[i].cuda()
                if i >= 1:
                    self.res_cnn_layers[i - 1] = self.res_cnn_layers[i - 1].cuda()
                    self.res_cnn_layers_back[i - 1] = self.res_cnn_layers_back[i - 1].cuda()

    def get_tags(self, gaz_list, word_inputs, biword_inputs, layer_gaz,
                 gaz_mask_input, mask):
        batch_size = word_inputs.size(0)
        seq_len = word_inputs.size(1)

        word_embs = self.word_embedding(word_inputs)
        if self.use_biword:
            biword_embs = self.biword_embedding(biword_inputs)
            word_embs = torch.cat([word_embs, biword_embs], dim=2)
        if self.use_posi:
            posi_inputs = torch.zeros(batch_size, seq_len).long()
            posi_inputs[:, :] = torch.LongTensor([i + 1 for i in range(seq_len)])
            if self.gpu:
                posi_inputs = posi_inputs.cuda()
            position_embs = self.position_embedding(posi_inputs)
            word_embs = torch.cat([word_embs, position_embs], dim=2)

        word_inputs_d = self.drop(word_embs)
        word_inputs_d = word_inputs_d.transpose(2, 1).contiguous()

        X_pre = self.cnn_layer0(word_inputs_d)  # (batch_size,hidden_size,seq_len)
        X_pre = torch.tanh(X_pre)
        X_trans = X_pre.transpose(2, 1).contiguous()
        global_matrix0 = self.global_gate(X_trans)  # G0
        X_list = X_trans.unsqueeze(2)  # (batch_size,seq_len,num_layer,hidden_size)

        padding = torch.zeros(batch_size, self.hidden_dim, 1)
        if self.gpu:
            padding = padding.cuda()

        feed_back = None
        for iteration in range(self.rethink_iter):
            global_matrix = global_matrix0
            X_pre = self.drop(X_pre)
            X_pre_padding = torch.cat([X_pre, padding], dim=2)  # (batch_size,hidden_size,seq_len+1)
            X_pre_padding_back = torch.cat([padding, X_pre], dim=2)

            for layer in range(self.num_layer - 1):
                X = self.cnn_layers[layer](X_pre_padding)  # (batch_size,hidden_size,seq_len)
                X = torch.tanh(X)
                X_back = self.cnn_layers_back[layer](X_pre_padding_back)
                X_back = torch.tanh(X_back)

                if layer > 0:
                    windowpad = torch.cat([padding for i in range(layer)], dim=2)
                    X_pre_padding_w = torch.cat([X_pre, windowpad, padding], dim=2)
                    X_res = self.res_cnn_layers[layer - 1](X_pre_padding_w)
                    X_res = torch.tanh(X_res)
                    X_pre_padding_w_back = torch.cat([padding, windowpad, X_pre], dim=2)
                    X_res_back = self.res_cnn_layers_back[layer - 1](X_pre_padding_w_back)
                    X_res_back = torch.tanh(X_res_back)

                layer_gaz_back = torch.zeros(batch_size, seq_len).long()
                if seq_len > layer + 1:
                    layer_gaz_back[:, layer + 1:] = layer_gaz[:, :seq_len - layer - 1, layer]
                if self.gpu:
                    layer_gaz_back = layer_gaz_back.cuda()

                gazs_embeds = self.gaz_embedding(layer_gaz[:, :, layer])
                gazs_embeds_back = self.gaz_embedding(layer_gaz_back)
                mask_gaz = (mask == 0).unsqueeze(-1).repeat(1, 1, self.gaz_emb_dim)
                gazs_embeds = gazs_embeds.masked_fill(mask_gaz, 0)
                gazs_embeds_back = gazs_embeds_back.masked_fill(mask_gaz, 0)
                gazs_embeds = self.drop(gazs_embeds)
                gazs_embeds_back = self.drop(gazs_embeds_back)

                if layer > 0:  # res
                    X_input = torch.cat([X, X_back, X_res, X_res_back],
                                        dim=-1).transpose(2, 1).contiguous()  # (b,4l,h)
                    X, X_back, X_res, X_res_back = self.layer_gate(
                        X_input, gazs_embeds, gazs_embeds_back, global_matrix,
                        exper_input=feed_back, gaz_mask=None)
                    X = X + X_back + X_res + X_res_back
                else:
                    X_input = torch.cat([X, X_back, X, X_back],
                                        dim=-1).transpose(2, 1).contiguous()  # (b,4l,h)
                    X, X_back, _, _ = self.layer_gate(
                        X_input, gazs_embeds, gazs_embeds_back, global_matrix,
                        exper_input=feed_back, gaz_mask=None)
                    X = X + X_back

                global_matrix = self.global_gate(X, global_matrix)

                if iteration == self.rethink_iter - 1:
                    X_list = torch.cat([X_list, X.unsqueeze(2)], dim=2)

                if layer == self.num_layer - 2:
                    feed_back = X

                X = X.transpose(2, 1).contiguous()
                X_d = self.drop(X)
                X_pre_padding = torch.cat([X_d, padding], dim=2)  # padding

                padding_back = torch.cat(
                    [padding for _ in range(min(layer + 2, seq_len + 1))], dim=2)
                if seq_len > layer + 1:
                    X_pre_padding_back = torch.cat(
                        [padding_back, X_d[:, :, :seq_len - layer - 1]],
                        dim=2)  # (b,h,seqlen+1)
                else:
                    X_pre_padding_back = padding_back

        X_attention = self.multiscale_layer(X_list)
        tags = self.hidden2tag(X_attention)  # (b,l,t)
        return tags

    def neg_log_likelihood_loss(self, gaz_list, word_inputs, biword_inputs,
                                word_seq_lengths, layer_gaz, gaz_mask, mask,
                                batch_label):
        tags = self.get_tags(gaz_list, word_inputs, biword_inputs, layer_gaz,
                             gaz_mask, mask)
        total_loss = self.crf.neg_log_likelihood_loss(tags, mask, batch_label)
        scores, tag_seq = self.crf._viterbi_decode(tags, mask)
        return total_loss, tag_seq

    def forward(self, gaz_list, word_inputs, biword_inputs, word_seq_lengths,
                layer_gaz, gaz_mask, mask):
        tags = self.get_tags(gaz_list, word_inputs, biword_inputs, layer_gaz,
                             gaz_mask, mask)
        scores, tag_seq = self.crf._viterbi_decode(tags, mask)
        return tag_seq
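# CNNmodel's frozen position embedding is built from get_sinusoid_encoding_table,
# which is defined elsewhere in the repo. The sketch below is the standard
# Transformer-style sinusoid table such a helper typically produces; this is an
# assumption about that helper, not a copy of its actual implementation.
import numpy as np
import torch

def sinusoid_encoding_table(n_position, d_model):
    # table[pos, i] = sin(pos / 10000^(2*(i//2)/d_model)) for even i, cos(...) for odd i
    positions = np.arange(n_position)[:, None]                       # (n_position, 1)
    div = np.power(10000, 2 * (np.arange(d_model) // 2) / d_model)   # (d_model,)
    table = positions / div
    table[:, 0::2] = np.sin(table[:, 0::2])
    table[:, 1::2] = np.cos(table[:, 1::2])
    return torch.FloatTensor(table)

# Usage mirroring CNNmodel.__init__: a frozen embedding indexed by 1-based positions
# (index 0 is left for padding).
pos_emb = torch.nn.Embedding.from_pretrained(sinusoid_encoding_table(50, 16), freeze=True)
print(pos_emb(torch.tensor([[1, 2, 3]])).shape)  # (1, 3, 16)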