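# The classes below come from several model files, so the exact import lines differ per
# module; the header here is an assumed, representative sketch rather than the original one.
# Two different CRF implementations appear in these snippets: BertNER and GazLSTM use a
# project CRF with an NCRF++-style interface (neg_log_likelihood_loss / _viterbi_decode,
# constructed with a gpu flag), while NERNet and ERENet use a batch_first CRF with the
# pytorch-crf interface (crf(emissions, tags, mask=..., reduction=...) and crf.decode).
# Project-specific components such as SentenceEncoder, ConditionalLayerNorm, NERmodel,
# batch_gather, BAIDU_ENTITY and BAIDU_RELATION are assumed to be defined elsewhere in the
# repository, and BertModel may come from either transformers or pytorch_pretrained_bert.
import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer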
class BertNER(nn.Module):
    def __init__(self, data):
        super(BertNER, self).__init__()
        self.gpu = data.HP_gpu
        self.use_bert = data.use_bert
        self.bertpath = data.bertpath

        char_feature_dim = 768
        print('total char_feature_dim is {}'.format(char_feature_dim))

        self.bert_encoder = BertModel.from_pretrained(self.bertpath)
        self.hidden2tag = nn.Linear(char_feature_dim, data.label_alphabet_size + 2)
        self.drop = nn.Dropout(p=data.HP_dropout)
        self.crf = CRF(data.label_alphabet_size, self.gpu)

        if self.gpu:
            self.bert_encoder = self.bert_encoder.cuda()
            self.hidden2tag = self.hidden2tag.cuda()
            self.crf = self.crf.cuda()

    def get_tags(self, batch_bert, bert_mask):
        seg_id = torch.zeros(bert_mask.size()).long().cuda() if self.gpu else torch.zeros(bert_mask.size()).long()
        outputs = self.bert_encoder(batch_bert, bert_mask, seg_id)
        outputs = outputs[0][:, 1:-1, :]
        tags = self.hidden2tag(outputs)
        return tags

    def neg_log_likelihood_loss(self, word_inputs, biword_inputs, word_seq_lengths, mask,
                                batch_label, batch_bert, bert_mask):
        tags = self.get_tags(batch_bert, bert_mask)
        total_loss = self.crf.neg_log_likelihood_loss(tags, mask, batch_label)
        scores, tag_seq = self.crf._viterbi_decode(tags, mask)
        return total_loss, tag_seq

    def forward(self, word_inputs, biword_inputs, word_seq_lengths, mask, batch_bert, bert_mask):
        tags = self.get_tags(batch_bert, bert_mask)
        scores, tag_seq = self.crf._viterbi_decode(tags, mask)
        return tag_seq
class GazLSTM(nn.Module):
    def __init__(self, data):
        super(GazLSTM, self).__init__()
        self.gpu = data.HP_gpu
        self.use_biword = data.use_bigram
        self.hidden_dim = data.HP_hidden_dim
        self.word_emb_dim = data.word_emb_dim
        self.biword_emb_dim = data.biword_emb_dim
        self.bilstm_flag = data.HP_bilstm
        self.lstm_layer = data.HP_lstm_layer
        self.num_layer = data.HP_num_layer
        self.model_type = data.model_type
        self.use_bert = data.use_bert
        self.device = data.device

        self.word_embedding = nn.Embedding(data.word_alphabet.size(), self.word_emb_dim, padding_idx=0)
        if data.pretrain_word_embedding is not None:
            self.word_embedding.weight.data.copy_(torch.from_numpy(data.pretrain_word_embedding))

        if self.use_biword:
            self.biword_embedding = nn.Embedding(data.biword_alphabet.size(), self.biword_emb_dim, padding_idx=0)
            if data.pretrain_biword_embedding is not None:
                self.biword_embedding.weight.data.copy_(torch.from_numpy(data.pretrain_biword_embedding))

        char_feature_dim = self.word_emb_dim
        if self.use_biword:
            char_feature_dim += self.biword_emb_dim
        if self.use_bert:
            char_feature_dim = char_feature_dim + 768 * 2
        print('total char_feature_dim is {}'.format(char_feature_dim))

        ## lstm model
        if self.model_type == 'lstm':
            lstm_hidden = self.hidden_dim
            if self.bilstm_flag:
                self.hidden_dim *= 2
            self.NERmodel = NERmodel(model_type='lstm', input_dim=char_feature_dim, hidden_dim=lstm_hidden,
                                     num_layer=self.lstm_layer, biflag=self.bilstm_flag)
            self.hidden2tag = nn.Linear(self.hidden_dim, data.label_alphabet_size + 2)

        # ## cnn model
        # if self.model_type == 'cnn':
        #     self.NERmodel = NERmodel(model_type='cnn', input_dim=char_feature_dim, hidden_dim=self.hidden_dim,
        #                              num_layer=self.num_layer, dropout=data.HP_dropout, gpu=self.gpu)
        #
        ## attention model
        if self.model_type == 'transformer':
            self.NERmodel = NERmodel(model_type='transformer', input_dim=char_feature_dim, hidden_dim=self.hidden_dim,
                                     num_layer=self.num_layer, dropout=data.HP_dropout)
            self.hidden2tag = nn.Linear(480, data.label_alphabet_size + 2)

        self.drop = nn.Dropout(p=data.HP_dropout)
        self.crf = CRF(data.label_alphabet_size, self.gpu, self.device)

        if self.use_bert:
            self.bert_encoder_1 = BertModel.from_pretrained('transformer_cpt/bert/')
            self.bert_encoder_2 = BertModel.from_pretrained('transformer_cpt/chinese_roberta_wwm_ext_pytorch/')
            for p in self.bert_encoder_1.parameters():
                p.requires_grad = False
            for p in self.bert_encoder_2.parameters():
                p.requires_grad = False

        if self.gpu:
            self.word_embedding = self.word_embedding.cuda(self.device)
            if self.use_biword:
                self.biword_embedding = self.biword_embedding.cuda(self.device)
            self.NERmodel = self.NERmodel.cuda(self.device)
            self.hidden2tag = self.hidden2tag.cuda(self.device)
            self.crf = self.crf.cuda(self.device)
            if self.use_bert:
                self.bert_encoder_1 = self.bert_encoder_1.cuda(self.device)
                self.bert_encoder_2 = self.bert_encoder_2.cuda(self.device)

    def get_tags(self, word_inputs, biword_inputs, mask, word_seq_lengths, batch_bert, bert_mask):
        batch_size = word_inputs.size()[0]
        seq_len = word_inputs.size()[1]

        word_embs = self.word_embedding(word_inputs)
        if self.use_biword:
            biword_embs = self.biword_embedding(biword_inputs)
            word_embs = torch.cat([word_embs, biword_embs], dim=-1)

        if self.model_type != 'transformer':
            word_inputs_d = self.drop(word_embs)  # (b,l,we)
        else:
            word_inputs_d = word_embs

        word_input_cat = torch.cat([word_inputs_d], dim=-1)  # (b,l,we+4*ge)

        if self.use_bert:
            seg_id = torch.zeros(bert_mask.size()).long().cuda(self.device) if self.gpu \
                else torch.zeros(bert_mask.size()).long()
            outputs_1 = self.bert_encoder_1(batch_bert, bert_mask, seg_id)
            outputs_1 = outputs_1[0][:, 1:-1, :]
            outputs_2 = self.bert_encoder_2(batch_bert, bert_mask, seg_id)
            outputs_2 = outputs_2[0][:, 1:-1, :]
            word_input_cat = torch.cat([word_input_cat, outputs_1, outputs_2], dim=-1)

        feature_out_d = self.NERmodel(word_input_cat, word_inputs.ne(0))
        tags = self.hidden2tag(feature_out_d)
        return tags

    def neg_log_likelihood_loss(self, word_inputs, biword_inputs, word_seq_lengths, mask, batch_label,
                                batch_bert, bert_mask):
        tags = self.get_tags(word_inputs, biword_inputs, mask, word_seq_lengths, batch_bert, bert_mask)
        total_loss = self.crf.neg_log_likelihood_loss(tags, mask, batch_label)
        scores, tag_seq = self.crf._viterbi_decode(tags, mask)
        return total_loss, tag_seq

    def forward(self, word_inputs, biword_inputs, word_seq_lengths, mask, batch_bert, bert_mask):
        tags = self.get_tags(word_inputs, biword_inputs, mask, word_seq_lengths, batch_bert, bert_mask)
        scores, tag_seq = self.crf._viterbi_decode(tags, mask)
        return tag_seq
class NERNet(nn.Module):
    """
    NERNet : Lstm+CRF
    """

    def __init__(self, args, model_conf):
        super(NERNet, self).__init__()
        char_emb = model_conf['char_emb']
        bichar_emb = model_conf['bichar_emb']
        embed_size = args.char_emb_dim
        if char_emb is not None:
            # self.char_emb = nn.Embedding.from_pretrained(char_emb, freeze=False, padding_idx=0)
            self.char_emb = nn.Embedding(num_embeddings=char_emb.shape[0], embedding_dim=char_emb.shape[1],
                                         padding_idx=0, _weight=char_emb)
            self.char_emb.weight.requires_grad = True
            embed_size = char_emb.size()[1]
        else:
            vocab_size = len(model_conf['char_vocab'])
            self.char_emb = nn.Embedding(num_embeddings=vocab_size, embedding_dim=args.char_emb_dim,
                                         padding_idx=0)

        self.bichar_emb = None
        if bichar_emb is not None:
            # self.bichar_emb = nn.Embedding.from_pretrained(bichar_emb, freeze=False, padding_idx=0)
            self.bichar_emb = nn.Embedding(num_embeddings=bichar_emb.shape[0], embedding_dim=bichar_emb.shape[1],
                                           padding_idx=0, _weight=bichar_emb)
            self.bichar_emb.weight.requires_grad = True
            embed_size += bichar_emb.size()[1]

        self.drop = nn.Dropout(p=0.5)
        # self.sentence_encoder = SentenceEncoder(args, embed_size)
        self.sentence_encoder = nn.LSTM(embed_size, args.hidden_size, num_layers=1, batch_first=True,
                                        bidirectional=True)
        self.emission = nn.Linear(args.hidden_size * 2, len(model_conf['entity_type']))
        self.crf = CRF(len(model_conf['entity_type']), batch_first=True)

    def forward(self, char_id, bichar_id, label_id=None, is_eval=False):
        # use anti-mask for answers-locator
        mask = char_id.eq(0)
        chars = self.char_emb(char_id)
        if self.bichar_emb is not None:
            bichars = self.bichar_emb(bichar_id)
            chars = torch.cat([chars, bichars], dim=-1)
        chars = self.drop(chars)

        # sen_encoded = self.sentence_encoder(chars, mask)
        sen_encoded, _ = self.sentence_encoder(chars)
        sen_encoded = self.drop(sen_encoded)

        bio_mask = char_id != 0
        emission = self.emission(sen_encoded)
        emission = F.log_softmax(emission, dim=-1)

        if not is_eval:
            crf_loss = -self.crf(emission, label_id, mask=bio_mask, reduction='mean')
            return crf_loss
        else:
            pred = self.crf.decode(emissions=emission, mask=bio_mask)
            # TODO: check
            max_len = char_id.size(1)
            temp_tag = copy.deepcopy(pred)
            for line in temp_tag:
                line.extend([0] * (max_len - len(line)))
            ent_pre = torch.tensor(temp_tag).to(emission.device)
            return ent_pre
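# A minimal smoke-test sketch for NERNet above, assuming the class (and its CRF dependency
# from the pytorch-crf package) is importable from this module. The hyperparameters, the
# dummy vocabulary/tag set, and the helper name _nernet_smoke_test are illustrative only,
# not the project's real configuration.
from types import SimpleNamespace


def _nernet_smoke_test():
    args = SimpleNamespace(char_emb_dim=64, hidden_size=128)
    model_conf = {
        'char_emb': None,                              # no pretrained char vectors in this sketch
        'bichar_emb': None,                            # no bichar vectors either
        'char_vocab': {i: i for i in range(100)},      # dummy 100-entry vocabulary
        'entity_type': ['O', 'B-PER', 'I-PER'],        # dummy BIO tag set
    }
    model = NERNet(args, model_conf)

    char_id = torch.randint(1, 100, (2, 10))           # (batch, seq_len); 0 would be padding
    bichar_id = torch.zeros_like(char_id)
    label_id = torch.randint(0, 3, (2, 10))

    loss = model(char_id, bichar_id, label_id, is_eval=False)   # CRF negative log-likelihood
    pred = model(char_id, bichar_id, is_eval=True)              # (batch, seq_len) decoded tag ids
    print(loss.item(), pred.shape)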
class GazLSTM(nn.Module):
    def __init__(self, data):
        super(GazLSTM, self).__init__()
        self.gpu = data.HP_gpu
        self.use_biword = data.use_bigram
        self.hidden_dim = data.HP_hidden_dim
        self.gaz_alphabet = data.gaz_alphabet
        self.gaz_emb_dim = data.gaz_emb_dim
        self.word_emb_dim = data.word_emb_dim
        self.biword_emb_dim = data.biword_emb_dim
        self.use_char = data.HP_use_char
        self.bilstm_flag = data.HP_bilstm
        self.lstm_layer = data.HP_lstm_layer
        self.use_count = data.HP_use_count
        self.num_layer = data.HP_num_layer
        self.model_type = data.model_type
        self.use_bert = data.use_bert

        scale = np.sqrt(3.0 / self.gaz_emb_dim)
        data.pretrain_gaz_embedding[0, :] = np.random.uniform(-scale, scale, [1, self.gaz_emb_dim])

        if self.use_char:
            scale = np.sqrt(3.0 / self.word_emb_dim)
            data.pretrain_word_embedding[0, :] = np.random.uniform(-scale, scale, [1, self.word_emb_dim])

        self.gaz_embedding = nn.Embedding(data.gaz_alphabet.size(), self.gaz_emb_dim)
        self.word_embedding = nn.Embedding(data.word_alphabet.size(), self.word_emb_dim)
        if self.use_biword:
            self.biword_embedding = nn.Embedding(data.biword_alphabet.size(), self.biword_emb_dim)

        if data.pretrain_gaz_embedding is not None:
            self.gaz_embedding.weight.data.copy_(torch.from_numpy(data.pretrain_gaz_embedding))
        else:
            self.gaz_embedding.weight.data.copy_(
                torch.from_numpy(self.random_embedding(data.gaz_alphabet.size(), self.gaz_emb_dim)))

        if data.pretrain_word_embedding is not None:
            self.word_embedding.weight.data.copy_(torch.from_numpy(data.pretrain_word_embedding))
        else:
            self.word_embedding.weight.data.copy_(
                torch.from_numpy(self.random_embedding(data.word_alphabet.size(), self.word_emb_dim)))

        if self.use_biword:
            if data.pretrain_biword_embedding is not None:
                self.biword_embedding.weight.data.copy_(torch.from_numpy(data.pretrain_biword_embedding))
            else:
                self.biword_embedding.weight.data.copy_(
                    torch.from_numpy(self.random_embedding(data.biword_alphabet.size(), self.word_emb_dim)))

        char_feature_dim = self.word_emb_dim + 4 * self.gaz_emb_dim
        if self.use_biword:
            char_feature_dim += self.biword_emb_dim
        if self.use_bert:
            char_feature_dim = char_feature_dim + 768
        print('total char_feature_dim {}'.format(char_feature_dim))

        ## lstm model
        if self.model_type == 'lstm':
            lstm_hidden = self.hidden_dim
            if self.bilstm_flag:
                self.hidden_dim *= 2
            self.NERmodel = NERmodel(model_type='lstm', input_dim=char_feature_dim, hidden_dim=lstm_hidden,
                                     num_layer=self.lstm_layer, biflag=self.bilstm_flag)

        ## cnn model
        if self.model_type == 'cnn':
            self.NERmodel = NERmodel(model_type='cnn', input_dim=char_feature_dim, hidden_dim=self.hidden_dim,
                                     num_layer=self.num_layer, dropout=data.HP_dropout, gpu=self.gpu)

        ## attention model
        if self.model_type == 'transformer':
            self.NERmodel = NERmodel(model_type='transformer', input_dim=char_feature_dim, hidden_dim=self.hidden_dim,
                                     num_layer=self.num_layer, dropout=data.HP_dropout)

        self.drop = nn.Dropout(p=data.HP_dropout)
        self.hidden2tag = nn.Linear(self.hidden_dim, data.label_alphabet_size + 2)
        self.crf = CRF(data.label_alphabet_size, self.gpu)

        if self.use_bert:
            self.bert_encoder = BertModel.from_pretrained('transformer_cpt/bert/')
            for p in self.bert_encoder.parameters():
                p.requires_grad = False

        if self.gpu:
            self.gaz_embedding = self.gaz_embedding.cuda()
            self.word_embedding = self.word_embedding.cuda()
            if self.use_biword:
                self.biword_embedding = self.biword_embedding.cuda()
            self.NERmodel = self.NERmodel.cuda()
            self.hidden2tag = self.hidden2tag.cuda()
            self.crf = self.crf.cuda()
            if self.use_bert:
                self.bert_encoder = self.bert_encoder.cuda()

    def get_tags(self, gaz_list, word_inputs, biword_inputs, layer_gaz, gaz_count, gaz_chars,
                 gaz_mask_input, gazchar_mask_input, mask, word_seq_lengths, batch_bert, bert_mask):
        batch_size = word_inputs.size()[0]
        seq_len = word_inputs.size()[1]
        max_gaz_num = layer_gaz.size(-1)
        gaz_match = []

        word_embs = self.word_embedding(word_inputs)
        if self.use_biword:
            biword_embs = self.biword_embedding(biword_inputs)
            word_embs = torch.cat([word_embs, biword_embs], dim=-1)

        if self.model_type != 'transformer':
            word_inputs_d = self.drop(word_embs)  # (b,l,we)
        else:
            word_inputs_d = word_embs

        if self.use_char:
            gazchar_embeds = self.word_embedding(gaz_chars)
            gazchar_mask = gazchar_mask_input.unsqueeze(-1).repeat(1, 1, 1, 1, 1, self.word_emb_dim)
            gazchar_embeds = gazchar_embeds.data.masked_fill_(gazchar_mask.data, 0)  # (b,l,4,gl,cl,ce)

            # gazchar_mask_input: (b,l,4,gl,cl)
            gaz_charnum = (gazchar_mask_input == 0).sum(dim=-1, keepdim=True).float()  # (b,l,4,gl,1)
            gaz_charnum = gaz_charnum + (gaz_charnum == 0).float()
            gaz_embeds = gazchar_embeds.sum(-2) / gaz_charnum  # (b,l,4,gl,ce)

            if self.model_type != 'transformer':
                gaz_embeds = self.drop(gaz_embeds)
            else:
                gaz_embeds = gaz_embeds
        else:  # use gaz embedding
            gaz_embeds = self.gaz_embedding(layer_gaz)

            if self.model_type != 'transformer':
                gaz_embeds_d = self.drop(gaz_embeds)
            else:
                gaz_embeds_d = gaz_embeds

            gaz_mask = gaz_mask_input.unsqueeze(-1).repeat(1, 1, 1, 1, self.gaz_emb_dim)
            gaz_embeds = gaz_embeds_d.data.masked_fill_(gaz_mask.data, 0)  # (b,l,4,g,ge)  ge: gaz_embed_dim

        if self.use_count:
            count_sum = torch.sum(gaz_count, dim=3, keepdim=True)  # (b,l,4,gn)
            count_sum = torch.sum(count_sum, dim=2, keepdim=True)  # (b,l,1,1)

            weights = gaz_count.div(count_sum)  # (b,l,4,g)
            weights = weights * 4
            weights = weights.unsqueeze(-1)
            gaz_embeds = weights * gaz_embeds  # (b,l,4,g,e)
            gaz_embeds = torch.sum(gaz_embeds, dim=3)  # (b,l,4,e)
        else:
            gaz_num = (gaz_mask_input == 0).sum(dim=-1, keepdim=True).float()  # (b,l,4,1)
            gaz_embeds = gaz_embeds.sum(-2) / gaz_num  # (b,l,4,ge)/(b,l,4,1)

        gaz_embeds_cat = gaz_embeds.view(batch_size, seq_len, -1)  # (b,l,4*ge)
        word_input_cat = torch.cat([word_inputs_d, gaz_embeds_cat], dim=-1)  # (b,l,we+4*ge)

        ### cat bert feature
        if self.use_bert:
            seg_id = torch.zeros(bert_mask.size()).long().cuda()
            outputs = self.bert_encoder(batch_bert, bert_mask, seg_id)
            outputs = outputs[0][:, 1:-1, :]
            word_input_cat = torch.cat([word_input_cat, outputs], dim=-1)

        feature_out_d = self.NERmodel(word_input_cat)
        tags = self.hidden2tag(feature_out_d)
        return tags, gaz_match

    def neg_log_likelihood_loss(self, gaz_list, word_inputs, biword_inputs, word_seq_lengths, layer_gaz,
                                gaz_count, gaz_chars, gaz_mask, gazchar_mask, mask, batch_label,
                                batch_bert, bert_mask):
        tags, _ = self.get_tags(gaz_list, word_inputs, biword_inputs, layer_gaz, gaz_count, gaz_chars,
                                gaz_mask, gazchar_mask, mask, word_seq_lengths, batch_bert, bert_mask)
        total_loss = self.crf.neg_log_likelihood_loss(tags, mask, batch_label)
        scores, tag_seq = self.crf._viterbi_decode(tags, mask)
        return total_loss, tag_seq

    def forward(self, gaz_list, word_inputs, biword_inputs, word_seq_lengths, layer_gaz, gaz_count,
                gaz_chars, gaz_mask, gazchar_mask, mask, batch_bert, bert_mask):
        tags, gaz_match = self.get_tags(gaz_list, word_inputs, biword_inputs, layer_gaz, gaz_count, gaz_chars,
                                        gaz_mask, gazchar_mask, mask, word_seq_lengths, batch_bert, bert_mask)
        scores, tag_seq = self.crf._viterbi_decode(tags, mask)
        return tag_seq, gaz_match
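# Hedged illustration of the count-weighted gaz pooling used in GazLSTM.get_tags above, run
# on dummy tensors only (shapes, counts, and the helper name are made up for this sketch).
# For each character, the four B/M/E/S gaz slots are pooled by weighting every matched
# lexicon word with its count, normalised over all counts at that position and rescaled by 4.
def _weighted_gaz_pooling_demo():
    b, l, g, ge = 1, 2, 3, 5                                # batch, seq_len, gaz per slot, gaz_emb_dim
    gaz_embeds = torch.randn(b, l, 4, g, ge)                # (b,l,4,g,ge)
    gaz_count = torch.tensor([[[[3., 1., 0.]] * 4] * l])    # (b,l,4,g) dummy match counts

    count_sum = torch.sum(gaz_count, dim=3, keepdim=True)   # (b,l,4,1)
    count_sum = torch.sum(count_sum, dim=2, keepdim=True)   # (b,l,1,1)
    weights = gaz_count.div(count_sum) * 4                  # weights over the 4 slots sum to 4
    pooled = (weights.unsqueeze(-1) * gaz_embeds).sum(dim=3)  # (b,l,4,ge)
    print(pooled.shape)                                     # torch.Size([1, 2, 4, 5])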
class ERENet(nn.Module):
    """
    ERENet : joint entity and relation extraction
    """

    def __init__(self, args, word_emb, ent_conf, spo_conf):
        print('mhs using only char2v+w2v mixed and word_emb is freeze')
        super(ERENet, self).__init__()
        self.max_len = args.max_len
        self.word_emb = nn.Embedding.from_pretrained(torch.tensor(word_emb, dtype=torch.float32), freeze=True,
                                                     padding_idx=0)
        self.char_emb = nn.Embedding(num_embeddings=args.char_vocab_size, embedding_dim=args.char_emb_size,
                                     padding_idx=0)
        self.word_convert_char = nn.Linear(args.word_emb_size, args.char_emb_size, bias=False)
        self.classes_num = len(spo_conf)

        self.first_sentence_encoder = SentenceEncoder(args, args.char_emb_size)
        # self.second_sentence_encoder = SentenceEncoder(args, args.hidden_size)
        # self.token_entity_emb = nn.Embedding(num_embeddings=2, embedding_dim=config.hidden_size,
        #                                      padding_idx=0)
        self.encoder_layer = TransformerEncoderLayer(args.hidden_size * 2, nhead=3)
        self.transformer_encoder = TransformerEncoder(self.encoder_layer, num_layers=1)
        self.LayerNorm = ConditionalLayerNorm(args.hidden_size * 2, eps=1e-12)
        # self.subject_dense = nn.Linear(args.hidden_size * 2, 2)
        self.ent_emission = nn.Linear(args.hidden_size * 2, len(ent_conf))
        self.ent_crf = CRF(len(ent_conf), batch_first=True)
        self.emission = nn.Linear(args.hidden_size * 2, len(spo_conf))
        self.crf = CRF(len(spo_conf), batch_first=True)
        self.loss_fct = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, q_ids=None, char_ids=None, word_ids=None, token_type_ids=None, subject_ids=None,
                subject_labels=None, object_labels=None, eval_file=None, is_eval=False):
        mask = char_ids != 0
        seq_mask = char_ids.eq(0)

        char_emb = self.char_emb(char_ids)
        word_emb = self.word_convert_char(self.word_emb(word_ids))
        # word_emb = self.word_emb(word_ids)
        emb = char_emb + word_emb
        # emb = char_emb
        # subject_encoder = sent_encoder + self.token_entity_emb(token_type_id)
        sent_encoder = self.first_sentence_encoder(emb, seq_mask)
        ent_emission = self.ent_emission(sent_encoder)

        if not is_eval:
            # subject_encoder = self.token_entity_emb(token_type_ids)
            # context_encoder = bert_encoder + subject_encoder
            sub_start_encoder = batch_gather(sent_encoder, subject_ids[:, 0])
            sub_end_encoder = batch_gather(sent_encoder, subject_ids[:, 1])
            subject = torch.cat([sub_start_encoder, sub_end_encoder], 1)
            context_encoder = self.LayerNorm(sent_encoder, subject)
            context_encoder = self.transformer_encoder(context_encoder.transpose(1, 0),
                                                       src_key_padding_mask=seq_mask).transpose(0, 1)

            ent_loss = -self.ent_crf(ent_emission, subject_labels, mask=mask, reduction='mean')
            emission = self.emission(context_encoder)
            po_loss = -self.crf(emission, object_labels, mask=mask, reduction='mean')
            loss = ent_loss + po_loss
            return loss
        else:
            subject_preds = self.ent_crf.decode(emissions=ent_emission, mask=mask)

            answer_list = list()
            for qid, sub_pred in zip(q_ids.cpu().numpy(), subject_preds):
                seq_len = min(len(eval_file[qid].context), self.max_len)
                tag_list = list()
                j = 0
                while j < seq_len:
                    end = j
                    flag = True
                    if sub_pred[j] == 1:
                        start = j
                        for k in range(start + 1, seq_len):
                            if sub_pred[k] != sub_pred[start] + 1:
                                end = k - 1
                                flag = False
                                break
                        if flag:
                            end = seq_len - 1
                        tag_list.append((start, end))
                    j = end + 1
                answer_list.append(tag_list)

            qid_ids, sent_encoders, pass_ids, subject_ids, token_type_ids = [], [], [], [], []
            for i, subjects in enumerate(answer_list):
                if subjects:
                    qid = q_ids[i].unsqueeze(0).expand(len(subjects))
                    pass_tensor = char_ids[i, :].unsqueeze(0).expand(len(subjects), char_ids.size(1))
                    new_sent_encoder = sent_encoder[i, :, :].unsqueeze(0).expand(len(subjects), sent_encoder.size(1),
                                                                                 sent_encoder.size(2))

                    token_type_id = torch.zeros((len(subjects), char_ids.size(1)), dtype=torch.long)
                    for index, (start, end) in enumerate(subjects):
                        token_type_id[index, start:end + 1] = 1

                    qid_ids.append(qid)
                    pass_ids.append(pass_tensor)
                    subject_ids.append(torch.tensor(subjects, dtype=torch.long))
                    sent_encoders.append(new_sent_encoder)
                    token_type_ids.append(token_type_id)

            if len(qid_ids) == 0:
                # print('len(qid_list)==0:')
                subject_ids = torch.zeros(1, 2).long().to(sent_encoder.device)
                qid_tensor = torch.tensor([-1], dtype=torch.long).to(sent_encoder.device)
                po_tensor = torch.zeros(1, sent_encoder.size(1)).long().to(sent_encoder.device)
                return qid_tensor, subject_ids, po_tensor

            qids = torch.cat(qid_ids).to(sent_encoder.device)
            pass_ids = torch.cat(pass_ids).to(sent_encoder.device)
            sent_encoders = torch.cat(sent_encoders).to(sent_encoder.device)
            # token_type_ids = torch.cat(token_type_ids).to(bert_encoder.device)
            subject_ids = torch.cat(subject_ids).to(sent_encoder.device)

            flag = False
            split_heads = 1024

            sent_encoders_ = torch.split(sent_encoders, split_heads, dim=0)
            pass_ids_ = torch.split(pass_ids, split_heads, dim=0)
            # token_type_ids_ = torch.split(token_type_ids, split_heads, dim=0)
            subject_encoder_ = torch.split(subject_ids, split_heads, dim=0)

            po_preds = list()
            for i in range(len(subject_encoder_)):
                sent_encoders = sent_encoders_[i]
                # token_type_ids = token_type_ids_[i]
                pass_ids = pass_ids_[i]
                subject_encoder = subject_encoder_[i]

                if sent_encoders.size(0) == 1:
                    flag = True
                    sent_encoders = sent_encoders.expand(2, sent_encoders.size(1), sent_encoders.size(2))
                    subject_encoder = subject_encoder.expand(2, subject_encoder.size(1))
                    pass_ids = pass_ids.expand(2, pass_ids.size(1))

                sub_start_encoder = batch_gather(sent_encoders, subject_encoder[:, 0])
                sub_end_encoder = batch_gather(sent_encoders, subject_encoder[:, 1])
                subject = torch.cat([sub_start_encoder, sub_end_encoder], 1)
                context_encoder = self.LayerNorm(sent_encoders, subject)
                context_encoder = self.transformer_encoder(context_encoder.transpose(1, 0),
                                                           src_key_padding_mask=pass_ids.eq(0)).transpose(0, 1)

                emission = self.emission(context_encoder)
                po_pred = self.crf.decode(emissions=emission, mask=(pass_ids != 0))
                max_len = pass_ids.size(1)
                temp_tag = copy.deepcopy(po_pred)
                for line in temp_tag:
                    line.extend([0] * (max_len - len(line)))
                # TODO: check
                po_pred = torch.tensor(temp_tag).to(emission.device)

                if flag:
                    po_pred = po_pred[1, :].unsqueeze(0)

                po_preds.append(po_pred)

            po_tensor = torch.cat(po_preds).to(qids.device)
            # print(subject_ids.device)
            # print(po_tensor.device)
            # print(qids.shape)
            # print(subject_ids.shape)
            # print(po_tensor.shape)
            return qids, subject_ids, po_tensor
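# batch_gather above is a project helper that is not shown in this file. A plausible minimal
# version is sketched here under the assumption that it picks, for every sequence in the
# batch, the encoder state at one given token position and returns a (batch, hidden) tensor,
# so that concatenating the subject start and end states along dim 1 yields the
# (batch, 2 * hidden) condition expected by ConditionalLayerNorm(args.hidden_size * 2).
def batch_gather_sketch(encodings: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
    # encodings: (batch, seq_len, hidden); positions: (batch,) long indices
    index = positions.view(-1, 1, 1).expand(-1, 1, encodings.size(-1))   # (batch, 1, hidden)
    return torch.gather(encodings, dim=1, index=index).squeeze(1)        # (batch, hidden)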
class ERENet(nn.Module):
    """
    ERENet : entity relation extraction
    """

    def __init__(self, args, word_emb):
        super(ERENet, self).__init__()
        print('mhs with w2v')
        if args.activation.lower() == 'relu':
            self.activation = nn.ReLU()
        elif args.activation.lower() == 'tanh':
            self.activation = nn.Tanh()

        self.word_emb = nn.Embedding.from_pretrained(torch.tensor(word_emb, dtype=torch.float32), freeze=True,
                                                     padding_idx=0)
        self.word_convert_char = nn.Linear(args.word_emb_size, args.char_emb_size, bias=False)
        self.char_emb = nn.Embedding(num_embeddings=args.char_vocab_size, embedding_dim=args.char_emb_size,
                                     padding_idx=0)
        self.rel_emb = nn.Embedding(num_embeddings=len(BAIDU_RELATION), embedding_dim=args.rel_emb_size)
        self.ent_emb = nn.Embedding(num_embeddings=len(BAIDU_ENTITY), embedding_dim=args.ent_emb_size)

        self.sentence_encoder = SentenceEncoder(args, args.char_emb_size)
        self.emission = nn.Linear(args.hidden_size * 2, len(BAIDU_ENTITY))
        self.crf = CRF(len(BAIDU_ENTITY), batch_first=True)
        self.selection_u = nn.Linear(2 * args.hidden_size + args.ent_emb_size, args.rel_emb_size)
        self.selection_v = nn.Linear(2 * args.hidden_size + args.ent_emb_size, args.rel_emb_size)
        self.selection_uv = nn.Linear(2 * args.rel_emb_size, args.rel_emb_size)

    def forward(self, char_ids=None, word_ids=None, label_ids=None, spo_ids=None, is_eval=False):
        # Entity Extraction
        mask = char_ids.eq(0)
        char_emb = self.char_emb(char_ids)
        word_emb = self.word_convert_char(self.word_emb(word_ids))
        emb = char_emb + word_emb
        sent_encoder = self.sentence_encoder(emb, mask)

        bio_mask = char_ids != 0
        emission = self.emission(sent_encoder)
        # TODO: check
        ent_pre = self.entity_decoder(bio_mask, emission, max_len=sent_encoder.size(1))

        # Relation Extraction
        if is_eval:
            ent_encoder = self.ent_emb(ent_pre)
        else:
            ent_encoder = self.ent_emb(label_ids)

        rel_encoder = torch.cat((sent_encoder, ent_encoder), dim=2)

        B, L, H = rel_encoder.size()
        u = self.activation(self.selection_u(rel_encoder)).unsqueeze(1).expand(B, L, L, -1)
        v = self.activation(self.selection_v(rel_encoder)).unsqueeze(2).expand(B, L, L, -1)
        uv = self.activation(self.selection_uv(torch.cat((u, v), dim=-1)))
        selection_logits = torch.einsum('bijh,rh->birj', uv, self.rel_emb.weight)

        if is_eval:
            return ent_pre, selection_logits
        else:
            crf_loss = -self.crf(emission, label_ids, mask=bio_mask, reduction='mean')
            selection_loss = self.masked_BCEloss(bio_mask, selection_logits, spo_ids)
            loss = crf_loss + selection_loss
            return loss

    def masked_BCEloss(self, mask, selection_logits, selection_gold):
        # batch x seq x rel x seq
        selection_mask = (mask.unsqueeze(2) * mask.unsqueeze(1)).unsqueeze(2).expand(-1, -1, len(BAIDU_RELATION), -1)
        selection_loss = F.binary_cross_entropy_with_logits(selection_logits, selection_gold, reduction='none')
        selection_loss = selection_loss.masked_select(selection_mask).sum()
        selection_loss /= mask.sum()
        return selection_loss

    def entity_decoder(self, bio_mask, emission, max_len):
        decoded_tag = self.crf.decode(emissions=emission, mask=bio_mask)

        temp_tag = copy.deepcopy(decoded_tag)
        for line in temp_tag:
            line.extend([0] * (max_len - len(line)))
        # TODO: check
        ent_pre = torch.tensor(temp_tag).to(emission.device)
        # print('entity predict embedding device is {}'.format(ent_pre.device))
        return ent_pre
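# Shape check for the multi-head selection scoring in ERENet.forward above, using random
# tensors with illustrative dimensions (the function name and sizes are not from the repo).
# uv holds one pairwise feature vector per (token i, token j) pair, and the einsum scores
# each pair against every relation embedding, producing logits of shape
# (batch, seq, num_relations, seq), which is exactly the layout masked_BCEloss expects.
def _selection_scoring_demo():
    B, L, R, H = 2, 6, 4, 8                         # batch, seq_len, num_relations, rel_emb_size
    uv = torch.randn(B, L, L, H)                    # pairwise features, as built from u and v
    rel_weight = torch.randn(R, H)                  # stands in for rel_emb.weight
    selection_logits = torch.einsum('bijh,rh->birj', uv, rel_weight)
    print(selection_logits.shape)                   # torch.Size([2, 6, 4, 6])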