def __init__(self, config, batch_size, is_training, num_labels=11, use_crf=False,
             tag_to_index=None, dropout_prob=0.0, use_one_hot_embeddings=False):
    super(BertNER, self).__init__()
    self.bert = BertNERModel(config, is_training, num_labels, use_crf,
                             dropout_prob, use_one_hot_embeddings)
    if use_crf:
        if not tag_to_index:
            raise Exception("The dict for tag-index mapping should be provided for CRF.")
        from src.CRF import CRF
        # The CRF loss needs the tag vocabulary and the fixed sequence length.
        self.loss = CRF(tag_to_index, batch_size, config.seq_length, is_training)
    else:
        self.loss = CrossEntropyCalculation(is_training)
    self.num_labels = num_labels
    self.use_crf = use_crf
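# Hedged sketch (not from the original code): one way to build the tag_to_index
# mapping that the CRF branch above requires. The label set below is an
# illustrative BIO scheme, not the project's actual tag vocabulary.
EXAMPLE_LABELS = ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"]
example_tag_to_index = {tag: idx for idx, tag in enumerate(EXAMPLE_LABELS)}
# A BertNER built with use_crf=True and tag_to_index=example_tag_to_index would
# then route the loss through the CRF instead of CrossEntropyCalculation.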
# Assumed imports for this module: CRF comes from the pytorch-crf package
# (torchcrf) and seq_len_to_mask from fastNLP; MyDropout is a custom dropout
# module assumed to be defined elsewhere in this repo.
import torch
import torch.nn as nn
from torchcrf import CRF
from fastNLP import seq_len_to_mask


class BERT_SeqLabel(nn.Module):
    def __init__(self, bert_embedding, label_size, vocabs, after_bert, use_pos_tag=True):
        super().__init__()
        self.after_bert = after_bert
        self.bert_embedding = bert_embedding
        self.label_size = label_size
        self.vocabs = vocabs
        self.hidden_size = bert_embedding._embed_size
        self.use_pos_tag = use_pos_tag
        self.pos_feats_size = 0
        if self.use_pos_tag:
            # POS tags get a small trainable embedding concatenated to the BERT output.
            self.pos_embed_size = len(list(vocabs['pos_tag']))
            self.pos_feats_size = 20
            self.pos_embedding = nn.Embedding(self.pos_embed_size, self.pos_feats_size)
        if self.after_bert == 'lstm':
            # Bidirectional LSTM whose output size matches its input size
            # (per-direction hidden size is half the input size).
            self.lstm = nn.LSTM(
                bert_embedding._embed_size + self.pos_feats_size,
                (bert_embedding._embed_size + self.pos_feats_size) // 2,
                bidirectional=True,
                num_layers=2,
                batch_first=True,
            )
        self.output = nn.Linear(self.hidden_size + self.pos_feats_size, self.label_size)
        self.dropout = MyDropout(0.2)
        # self.crf = get_crf_zero_init(self.label_size)
        self.crf = CRF(num_tags=self.label_size, batch_first=True)

    def forward(self, lattice, bigrams, seq_len, lex_num, pos_s, pos_e, pos_tag,
                target=None, chars_target=None):
        batch_size = lattice.size(0)
        max_seq_len_and_lex_num = lattice.size(1)
        max_seq_len = bigrams.size(1)

        # Keep only the character positions of the lattice and fill padding positions.
        words = lattice[:, :max_seq_len]
        mask = seq_len_to_mask(seq_len).bool()
        words.masked_fill_(~mask, self.vocabs['lattice'].padding_idx)

        encoded = self.bert_embedding(words)
        if self.use_pos_tag:
            pos_embed = self.pos_embedding(pos_tag)
            encoded = torch.cat([encoded, pos_embed], dim=-1)

        if self.after_bert == 'lstm':
            # nn.LSTM does not take sequence lengths directly, so pack the padded batch.
            packed = nn.utils.rnn.pack_padded_sequence(
                encoded, seq_len.cpu(), batch_first=True, enforce_sorted=False)
            packed_out, _ = self.lstm(packed)
            encoded, _ = nn.utils.rnn.pad_packed_sequence(
                packed_out, batch_first=True, total_length=max_seq_len)

        encoded = self.dropout(encoded)
        pred = self.output(encoded)

        if self.training:
            # The CRF returns the log-likelihood; negate it to get a loss to minimize.
            # loss = self.crf(pred, target, mask).mean(dim=0)
            log_likelihood = self.crf(emissions=pred, tags=target, mask=mask, reduction='mean')
            return {'loss': -log_likelihood}
        else:
            # decode() returns the best tag sequence per sample as a list of lists.
            pred = self.crf.decode(emissions=pred, mask=mask)
            # pred, path = self.crf.viterbi_decode(pred, mask)
            return {'pred': pred}
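# Hedged sketch (not part of the original model): a self-contained illustration
# of the pytorch-crf (torchcrf) calls assumed in BERT_SeqLabel.forward above.
# Shapes, tag count, and tensor values are made up for the example.
if __name__ == '__main__':
    crf = CRF(num_tags=5, batch_first=True)
    emissions = torch.randn(2, 7, 5)                  # (batch, seq_len, num_tags)
    tags = torch.randint(0, 5, (2, 7))                # gold tag ids
    mask = torch.ones(2, 7, dtype=torch.bool)
    mask[1, 5:] = False                               # second sequence is shorter

    loss = -crf(emissions, tags, mask=mask, reduction='mean')   # NLL to minimize
    best_paths = crf.decode(emissions, mask=mask)                # one tag list per sample
    print(loss.item(), best_paths)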