Example #1
class SeqModel(nn.Module):
    def __init__(self, data):
        super(SeqModel, self).__init__()
        self.data = data
        self.use_crf = data.use_crf
        print("build network...")
        print("word feature extractor: ", data.word_feature_extractor)

        self.gpu = data.HP_gpu
        self.average_batch = data.average_batch_loss
        # extract opinion and evidence separately
        label_size = data.label_alphabet_size
        self.word_hidden = WordSequence(data)
        if self.use_crf:
            self.word_crf = CRF(label_size, batch_first=True)
            if self.gpu:
                self.word_crf = self.word_crf.cuda()

    def neg_log_likelihood_loss(self, word_inputs, word_seq_lengths,
                                batch_label, mask, input_label_seq_tensor):
        lstm_outs = self.word_hidden(word_inputs, word_seq_lengths,
                                     input_label_seq_tensor)
        # lstm_outs(batch_size,sentence_length,tag_size)
        batch_size = word_inputs.size(0)
        if self.use_crf:
            mask = mask.byte()
            loss = (-self.word_crf(lstm_outs, batch_label, mask))
            tag_seq = self.word_crf.decode(lstm_outs, mask)
        else:
            loss_function = nn.NLLLoss()
            seq_len = lstm_outs.size(1)
            lstm_outs = lstm_outs.view(batch_size * seq_len, -1)
            score = F.log_softmax(lstm_outs, 1)
            loss = loss_function(
                score,
                batch_label.contiguous().view(batch_size * seq_len))
            _, tag_seq = torch.max(score, 1)
            tag_seq = tag_seq.view(batch_size, seq_len)
        return loss, tag_seq

    def evaluate(self, word_inputs, word_seq_lengths, mask,
                 input_label_seq_tensor):
        lstm_outs = self.word_hidden(word_inputs, word_seq_lengths,
                                     input_label_seq_tensor)
        if self.use_crf:
            mask = mask.byte()
            tag_seq = self.word_crf.decode(lstm_outs, mask)
        else:
            batch_size = word_inputs.size(0)
            seq_len = lstm_outs.size(1)
            lstm_outs = lstm_outs.view(batch_size * seq_len, -1)
            _, tag_seq = torch.max(lstm_outs, 1)
            tag_seq = mask.long() * tag_seq.view(batch_size, seq_len)
        return tag_seq

    def forward(self, word_inputs, word_seq_lengths, mask,
                input_label_seq_tensor):
        return self.evaluate(word_inputs, word_seq_lengths, mask,
                             input_label_seq_tensor)
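
A self-contained sketch (toy shapes, separate from the class above) of what the non-CRF branch in neg_log_likelihood_loss computes: token scores are flattened to (batch_size * seq_len, tag_size), NLLLoss is applied to the log-softmax scores, and the argmax is reshaped back to (batch_size, seq_len).

import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size, seq_len, tag_size = 2, 5, 4              # toy dimensions
lstm_outs = torch.randn(batch_size, seq_len, tag_size)
batch_label = torch.randint(0, tag_size, (batch_size, seq_len))

score = F.log_softmax(lstm_outs.view(batch_size * seq_len, -1), dim=1)
loss = nn.NLLLoss()(score, batch_label.view(batch_size * seq_len))
tag_seq = score.argmax(dim=1).view(batch_size, seq_len)
print(loss.item(), tag_seq.shape)                    # scalar loss, torch.Size([2, 5])
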
Example #2
class GazLSTM(nn.Module):
    def __init__(self, data):
        super(GazLSTM, self).__init__()

        self.gpu = data.HP_gpu
        self.use_biword = data.use_bigram
        self.hidden_dim = data.HP_hidden_dim
        self.gaz_alphabet = data.gaz_alphabet
        self.gaz_emb_dim = data.gaz_emb_dim
        self.word_emb_dim = data.word_emb_dim
        self.biword_emb_dim = data.biword_emb_dim
        self.use_char = data.HP_use_char
        self.bilstm_flag = data.HP_bilstm
        self.lstm_layer = data.HP_lstm_layer
        self.use_count = data.HP_use_count
        self.num_layer = data.HP_num_layer
        self.model_type = data.model_type
        self.use_bert = data.use_bert

        # self.use_gazcount = data.use_gazcount
        # whether to use the external dictionary features
        self.use_dictionary = data.use_dictionary
        self.simi_dic_emb = data.simi_dic_emb
        self.simi_dic_dim = data.simi_dic_dim

        scale = np.sqrt(3.0 / self.gaz_emb_dim)
        data.pretrain_gaz_embedding[0, :] = np.random.uniform(
            -scale, scale, [1, self.gaz_emb_dim])

        if self.use_char:
            scale = np.sqrt(3.0 / self.word_emb_dim)
            data.pretrain_word_embedding[0, :] = np.random.uniform(
                -scale, scale, [1, self.word_emb_dim])

        self.gaz_embedding = nn.Embedding(data.gaz_alphabet.size(),
                                          self.gaz_emb_dim)  # randomly initialized gaz embedding matrix
        self.word_embedding = nn.Embedding(data.word_alphabet.size(),
                                           self.word_emb_dim)  # randomly initialized word embedding matrix
        if self.use_biword:
            self.biword_embedding = nn.Embedding(data.biword_alphabet.size(),
                                                 self.biword_emb_dim)

        if data.pretrain_gaz_embedding is not None:
            self.gaz_embedding.weight.data.copy_(
                torch.from_numpy(data.pretrain_gaz_embedding)
            )  # copy data.pretrain_gaz_embedding into gaz_embedding
        else:
            self.gaz_embedding.weight.data.copy_(
                torch.from_numpy(
                    self.random_embedding(data.gaz_alphabet.size(),
                                          self.gaz_emb_dim)))

        if data.pretrain_word_embedding is not None:
            self.word_embedding.weight.data.copy_(
                torch.from_numpy(data.pretrain_word_embedding))
        else:
            self.word_embedding.weight.data.copy_(
                torch.from_numpy(
                    self.random_embedding(data.word_alphabet.size(),
                                          self.word_emb_dim)))
        if self.use_biword:
            if data.pretrain_biword_embedding is not None:
                self.biword_embedding.weight.data.copy_(
                    torch.from_numpy(data.pretrain_biword_embedding))
            else:
                self.biword_embedding.weight.data.copy_(
                    torch.from_numpy(
                        self.random_embedding(data.biword_alphabet.size(),
                                              self.biword_emb_dim)))

        use_gazcount = True
        # per-character input feature dimension
        char_feature_dim = self.word_emb_dim + 4 * self.gaz_emb_dim
        if self.use_dictionary:
            if use_gazcount:
                char_feature_dim += self.simi_dic_dim
            else:
                char_feature_dim = self.word_emb_dim  #+ self.simi_dic_dim

        if self.use_biword:
            char_feature_dim += self.biword_emb_dim

        if self.use_bert:
            char_feature_dim = char_feature_dim + 768

        ## lstm model
        if self.model_type == 'lstm':
            lstm_hidden = self.hidden_dim
            if self.bilstm_flag:
                self.hidden_dim *= 2

            self.NERmodel = NERmodel(model_type='lstm',
                                     input_dim=char_feature_dim,
                                     hidden_dim=lstm_hidden,
                                     num_layer=self.lstm_layer,
                                     biflag=self.bilstm_flag)

        ## cnn model
        if self.model_type == 'cnn':
            self.NERmodel = NERmodel(model_type='cnn',
                                     input_dim=char_feature_dim,
                                     hidden_dim=self.hidden_dim,
                                     num_layer=self.num_layer,
                                     dropout=data.HP_dropout,
                                     gpu=self.gpu)

        ## attention model
        if self.model_type == 'transformer':
            self.NERmodel = NERmodel(model_type='transformer',
                                     input_dim=char_feature_dim,
                                     hidden_dim=self.hidden_dim,
                                     num_layer=self.num_layer,
                                     dropout=data.HP_dropout)

        self.drop = nn.Dropout(p=data.HP_dropout)  # zero out activations with probability HP_dropout
        self.hidden2tag = nn.Linear(self.hidden_dim,
                                    data.label_alphabet_size + 2)
        self.crf = CRF(data.label_alphabet_size, self.gpu)

        if self.use_bert:
            self.bert_encoder = BertModel.from_pretrained('bert-base-chinese')
            for p in self.bert_encoder.parameters():
                p.requires_grad = False

        if self.gpu:
            self.gaz_embedding = self.gaz_embedding.cuda()
            self.word_embedding = self.word_embedding.cuda()
            if self.use_biword:
                self.biword_embedding = self.biword_embedding.cuda()
            self.NERmodel = self.NERmodel.cuda()
            self.hidden2tag = self.hidden2tag.cuda()
            self.crf = self.crf.cuda()

            if self.use_bert:
                self.bert_encoder = self.bert_encoder.cuda()

    # build the input embeddings and return per-token tag scores
    def get_tags(self, gaz_list, word_inputs, biword_inputs, layer_gaz,
                 gaz_count, gaz_chars, gaz_mask_input, gazchar_mask_input,
                 mask, word_seq_lengths, batch_bert, bert_mask, simi_value):
        use_gazcount = True
        batch_size = word_inputs.size()[0]
        seq_len = word_inputs.size()[1]
        max_gaz_num = layer_gaz.size(-1)
        gaz_match = []

        word_embs = self.word_embedding(word_inputs)

        if self.use_biword:
            biword_embs = self.biword_embedding(biword_inputs)
            word_embs = torch.cat([word_embs, biword_embs], dim=-1)

        if self.model_type != 'transformer':
            word_inputs_d = self.drop(word_embs)  #(b,l,we)
        else:
            word_inputs_d = word_embs

        if self.use_char:
            gazchar_embeds = self.word_embedding(gaz_chars)

            gazchar_mask = gazchar_mask_input.unsqueeze(-1).repeat(
                1, 1, 1, 1, 1, self.word_emb_dim)
            gazchar_embeds = gazchar_embeds.data.masked_fill_(
                gazchar_mask.data, 0)  #(b,l,4,gl,cl,ce)

            # gazchar_mask_input:(b,l,4,gl,cl)
            gaz_charnum = (gazchar_mask_input == 0).sum(
                dim=-1, keepdim=True).float()  #(b,l,4,gl,1)
            gaz_charnum = gaz_charnum + (gaz_charnum == 0).float()
            gaz_embeds = gazchar_embeds.sum(-2) / gaz_charnum  #(b,l,4,gl,ce)

            if self.model_type != 'transformer':
                gaz_embeds = self.drop(gaz_embeds)
            else:
                gaz_embeds = gaz_embeds

        else:  #use gaz embedding
            gaz_embeds = self.gaz_embedding(layer_gaz)

            if self.model_type != 'transformer':
                gaz_embeds_d = self.drop(gaz_embeds)
            else:
                gaz_embeds_d = gaz_embeds

            gaz_mask = gaz_mask_input.unsqueeze(-1).repeat(
                1, 1, 1, 1, self.gaz_emb_dim)

            gaz_embeds = gaz_embeds_d.data.masked_fill_(
                gaz_mask.data, 0)  #(b,l,4,g,ge)  ge:gaz_embed_dim

        if self.use_count:
            count_sum = torch.sum(gaz_count, dim=3, keepdim=True)  #(b,l,4,gn)
            count_sum = torch.sum(count_sum, dim=2, keepdim=True)  #(b,l,1,1)

            weights = gaz_count.div(count_sum)  #(b,l,4,g)
            weights = weights * 4
            weights = weights.unsqueeze(-1)
            gaz_embeds = weights * gaz_embeds  #(b,l,4,g,e)
            gaz_embeds = torch.sum(gaz_embeds, dim=3)  #(b,l,4,e)

        else:
            gaz_num = (gaz_mask_input == 0).sum(
                dim=-1, keepdim=True).float()  #(b,l,4,1)
            gaz_embeds = gaz_embeds.sum(-2) / gaz_num  #(b,l,4,ge)/(b,l,4,1)

        gaz_embeds_cat = gaz_embeds.view(batch_size, seq_len, -1)  #(b,l,4*ge)
        if self.use_dictionary:  # concatenate the dictionary similarity vectors
            simi_embeds = []
            for sent_values in simi_value:
                for value in sent_values:
                    # expand each similarity score to a simi_dic_dim-sized vector
                    simi_embeds.append([value] * self.simi_dic_dim)
            simi_embeds = torch.Tensor(simi_embeds)
            if self.gpu:
                simi_embeds = simi_embeds.cuda()
            simi_embeds_cat = simi_embeds.view(batch_size, seq_len, -1)
            self.simi_dic_emb = simi_embeds
            if use_gazcount:
                word_input_cat = torch.cat(
                    [word_inputs_d, gaz_embeds_cat, simi_embeds_cat], dim=-1)
            else:
                # word_input_cat = torch.cat([word_inputs_d, simi_embeds_cat], dim=-1)
                word_input_cat = word_inputs_d  # cat of a single tensor is a no-op

        else:
            word_input_cat = torch.cat([word_inputs_d, gaz_embeds_cat],
                                       dim=-1)  #(b,l,we+4*ge)

        ### cat bert feature
        if self.use_bert:
            seg_id = torch.zeros(bert_mask.size()).long()
            if self.gpu:
                seg_id = seg_id.cuda()
            outputs = self.bert_encoder(batch_bert, bert_mask, seg_id)
            outputs = outputs[0][:, 1:-1, :]
            word_input_cat = torch.cat([word_input_cat, outputs], dim=-1)

        feature_out_d = self.NERmodel(word_input_cat)

        tags = self.hidden2tag(feature_out_d)

        return tags, gaz_match

    def neg_log_likelihood_loss(self, gaz_list, word_inputs, biword_inputs,
                                word_seq_lengths, layer_gaz, gaz_count,
                                gaz_chars, gaz_mask, gazchar_mask, mask,
                                batch_label, batch_bert, bert_mask,
                                simi_value):

        tags, _ = self.get_tags(gaz_list, word_inputs, biword_inputs,
                                layer_gaz, gaz_count, gaz_chars, gaz_mask,
                                gazchar_mask, mask, word_seq_lengths,
                                batch_bert, bert_mask, simi_value)

        total_loss = self.crf.neg_log_likelihood_loss(tags, mask, batch_label)
        scores, tag_seq = self.crf._viterbi_decode(tags, mask)

        return total_loss, tag_seq

    def forward(self, gaz_list, word_inputs, biword_inputs, word_seq_lengths,
                layer_gaz, gaz_count, gaz_chars, gaz_mask, gazchar_mask, mask,
                batch_bert, bert_mask, simi_value):

        tags, gaz_match = self.get_tags(gaz_list, word_inputs, biword_inputs,
                                        layer_gaz, gaz_count, gaz_chars,
                                        gaz_mask, gazchar_mask, mask,
                                        word_seq_lengths, batch_bert,
                                        bert_mask, simi_value)

        scores, tag_seq = self.crf._viterbi_decode(tags, mask)

        return tag_seq, gaz_match
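
The count-based lexicon weighting in get_tags above can be isolated as follows. A self-contained sketch with toy shapes (the clamp guard against a zero count sum is added for the toy only, it is not in the model): each matched lexicon word is weighted by its corpus count, normalized over all matches of the character, and the weighted embeddings are summed within each of the four positional slots.

import torch

b, l, g, ge = 1, 2, 3, 5                            # batch, seq_len, words per slot, gaz dim
gaz_count = torch.randint(0, 4, (b, l, 4, g)).float()
gaz_embeds = torch.randn(b, l, 4, g, ge)

count_sum = gaz_count.sum(dim=3, keepdim=True).sum(dim=2, keepdim=True)  # (b,l,1,1)
weights = 4 * gaz_count / count_sum.clamp(min=1)                          # (b,l,4,g)
gaz_embeds = (weights.unsqueeze(-1) * gaz_embeds).sum(dim=3)              # (b,l,4,ge)
gaz_embeds_cat = gaz_embeds.view(b, l, -1)                                # (b,l,4*ge)
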
Example #3
class GazLSTM(nn.Module):
    def __init__(self, data):
        super(GazLSTM, self).__init__()
        self.gpu = data.HP_gpu
        self.use_biword = data.use_bigram
        self.hidden_dim = data.HP_hidden_dim
        self.gaz_alphabet = data.gaz_alphabet
        self.gaz_emb_dim = data.gaz_emb_dim
        self.word_emb_dim = data.word_emb_dim
        self.biword_emb_dim = data.biword_emb_dim
        self.use_char = data.HP_use_char
        self.bilstm_flag = data.HP_bilstm
        self.lstm_layer = data.HP_lstm_layer
        self.use_count = data.HP_use_count
        self.num_layer = data.HP_num_layer
        self.model_type = data.model_type

        scale = np.sqrt(3.0 / self.gaz_emb_dim)
        data.pretrain_gaz_embedding[0, :] = np.random.uniform(
            -scale, scale, [1, self.gaz_emb_dim])

        if self.use_char:
            scale = np.sqrt(3.0 / self.word_emb_dim)
            data.pretrain_word_embedding[0, :] = np.random.uniform(
                -scale, scale, [1, self.word_emb_dim])

        self.gaz_embedding = nn.Embedding(data.gaz_alphabet.size(),
                                          self.gaz_emb_dim)
        self.word_embedding = nn.Embedding(data.word_alphabet.size(),
                                           self.word_emb_dim)
        if self.use_biword:
            self.biword_embedding = nn.Embedding(data.biword_alphabet.size(),
                                                 self.biword_emb_dim)

        if data.pretrain_gaz_embedding is not None:
            self.gaz_embedding.weight.data.copy_(
                torch.from_numpy(data.pretrain_gaz_embedding))
        else:
            self.gaz_embedding.weight.data.copy_(
                torch.from_numpy(
                    self.random_embedding(data.gaz_alphabet.size(),
                                          self.gaz_emb_dim)))

        if data.pretrain_word_embedding is not None:
            self.word_embedding.weight.data.copy_(
                torch.from_numpy(data.pretrain_word_embedding))
        else:
            self.word_embedding.weight.data.copy_(
                torch.from_numpy(
                    self.random_embedding(data.word_alphabet.size(),
                                          self.word_emb_dim)))
        if self.use_biword:
            if data.pretrain_biword_embedding is not None:
                self.biword_embedding.weight.data.copy_(
                    torch.from_numpy(data.pretrain_biword_embedding))
            else:
                self.biword_embedding.weight.data.copy_(
                    torch.from_numpy(
                        self.random_embedding(data.biword_alphabet.size(),
                                              self.biword_emb_dim)))

        char_feature_dim = self.word_emb_dim + 4 * self.gaz_emb_dim
        if self.use_biword:
            char_feature_dim += self.biword_emb_dim

        ## lstm model
        if self.model_type == 'lstm':
            lstm_hidden = self.hidden_dim
            if self.bilstm_flag:
                self.hidden_dim *= 2
            self.NERmodel = NERmodel(model_type='lstm',
                                     input_dim=char_feature_dim,
                                     hidden_dim=lstm_hidden,
                                     num_layer=self.lstm_layer,
                                     biflag=self.bilstm_flag)

        ## cnn model
        if self.model_type == 'cnn':
            self.NERmodel = NERmodel(model_type='cnn',
                                     input_dim=char_feature_dim,
                                     hidden_dim=self.hidden_dim,
                                     num_layer=self.num_layer,
                                     dropout=data.HP_dropout,
                                     gpu=self.gpu)

        ## attention model
        if self.model_type == 'transformer':
            self.NERmodel = NERmodel(model_type='transformer',
                                     input_dim=char_feature_dim,
                                     hidden_dim=self.hidden_dim,
                                     num_layer=self.num_layer,
                                     dropout=data.HP_dropout)

        self.drop = nn.Dropout(p=data.HP_dropout)
        self.hidden2tag = nn.Linear(self.hidden_dim,
                                    data.label_alphabet_size + 2)
        self.crf = CRF(data.label_alphabet_size, self.gpu)

        if self.gpu:
            #self.drop = self.drop.cuda()
            self.gaz_embedding = self.gaz_embedding.cuda()
            self.word_embedding = self.word_embedding.cuda()
            if self.use_biword:
                self.biword_embedding = self.biword_embedding.cuda()
            self.NERmodel = self.NERmodel.cuda()
            self.hidden2tag = self.hidden2tag.cuda()
            self.crf = self.crf.cuda()

    def get_tags(self, gaz_list, word_inputs, biword_inputs, layer_gaz,
                 gaz_count, gaz_chars, gaz_mask_input, gazchar_mask_input,
                 mask):

        batch_size = word_inputs.size()[0]
        seq_len = word_inputs.size()[1]
        max_gaz_num = layer_gaz.size(-1)
        gaz_match = []

        word_embs = self.word_embedding(word_inputs)

        if self.use_biword:
            biword_embs = self.biword_embedding(biword_inputs)
            word_embs = torch.cat([word_embs, biword_embs], dim=-1)

        if self.model_type != 'transformer':
            word_inputs_d = self.drop(word_embs)  #(b,l,we)
        else:
            word_inputs_d = word_embs

        if self.use_char:
            gazchar_embeds = self.word_embedding(gaz_chars)

            gazchar_mask = gazchar_mask_input.unsqueeze(-1).repeat(
                1, 1, 1, 1, 1, self.word_emb_dim)
            gazchar_embeds = gazchar_embeds.data.masked_fill_(
                gazchar_mask.data, 0)  #(b,l,4,gl,cl,ce)

            # gazchar_mask_input:(b,l,4,gl,cl)
            gaz_charnum = (gazchar_mask_input == 0).sum(
                dim=-1, keepdim=True).float()  #(b,l,4,gl,1)
            gaz_charnum = gaz_charnum + (gaz_charnum == 0).float()
            gaz_embeds = gazchar_embeds.sum(-2) / gaz_charnum  #(b,l,4,gl,ce)

            if self.model_type != 'transformer':
                gaz_embeds = self.drop(gaz_embeds)
            else:
                gaz_embeds = gaz_embeds

        else:  #use gaz embedding
            gaz_embeds = self.gaz_embedding(layer_gaz)

            if self.model_type != 'transformer':
                gaz_embeds_d = self.drop(gaz_embeds)
            else:
                gaz_embeds_d = gaz_embeds

            gaz_mask = gaz_mask_input.unsqueeze(-1).repeat(
                1, 1, 1, 1, self.gaz_emb_dim)

            gaz_embeds = gaz_embeds_d.data.masked_fill_(
                gaz_mask.data, 0)  #(b,l,4,g,ge)  ge:gaz_embed_dim

        if self.use_count:
            count_sum = torch.sum(gaz_count, dim=3, keepdim=True)  #(b,l,4,gn)
            count_sum = torch.sum(count_sum, dim=2, keepdim=True)  #(b,l,1,1)

            weights = gaz_count.div(count_sum)  #(b,l,4,g)
            weights = weights * 4
            weights = weights.unsqueeze(-1)
            gaz_embeds = weights * gaz_embeds  #(b,l,4,g,e)
            gaz_embeds = torch.sum(gaz_embeds, dim=3)  #(b,l,4,e)

        else:
            gaz_num = (gaz_mask_input == 0).sum(
                dim=-1, keepdim=True).float()  #(b,l,4,1)
            gaz_embeds = gaz_embeds.sum(-2) / gaz_num  #(b,l,4,ge)/(b,l,4,1)

        gaz_embeds_cat = gaz_embeds.view(batch_size, seq_len, -1)  #(b,l,4*ge)

        word_input_cat = torch.cat([word_inputs_d, gaz_embeds_cat],
                                   dim=-1)  #(b,l,we+4*ge)

        feature_out_d = self.NERmodel(word_input_cat)

        tags = self.hidden2tag(feature_out_d)

        return tags, gaz_match

    def neg_log_likelihood_loss(self, gaz_list, word_inputs, biword_inputs,
                                word_seq_lengths, layer_gaz, gaz_count,
                                gaz_chars, gaz_mask, gazchar_mask, mask,
                                batch_label):

        tags, _ = self.get_tags(gaz_list, word_inputs, biword_inputs,
                                layer_gaz, gaz_count, gaz_chars, gaz_mask,
                                gazchar_mask, mask)

        total_loss = self.crf.neg_log_likelihood_loss(tags, mask, batch_label)
        scores, tag_seq = self.crf._viterbi_decode(tags, mask)

        return total_loss, tag_seq

    def forward(self, gaz_list, word_inputs, biword_inputs, word_seq_lengths,
                layer_gaz, gaz_count, gaz_chars, gaz_mask, gazchar_mask, mask):

        tags, gaz_match = self.get_tags(gaz_list, word_inputs, biword_inputs,
                                        layer_gaz, gaz_count, gaz_chars,
                                        gaz_mask, gazchar_mask, mask)

        scores, tag_seq = self.crf._viterbi_decode(tags, mask)

        return tag_seq, gaz_match
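
A self-contained sketch (toy shapes, separate from the class above) of the branch taken when HP_use_count is false: padded lexicon entries are zeroed out with the mask, and each of the four slots is averaged over its real matches.

import torch

b, l, g, ge = 1, 2, 3, 5                             # batch, seq_len, words per slot, gaz dim
gaz_mask_input = torch.zeros(b, l, 4, g, dtype=torch.bool)
gaz_mask_input[..., 2:] = True                       # pretend the last entry of each slot is padding
gaz_embeds = torch.randn(b, l, 4, g, ge)

gaz_embeds = gaz_embeds.masked_fill(gaz_mask_input.unsqueeze(-1), 0)
gaz_num = (gaz_mask_input == 0).sum(dim=-1, keepdim=True).float()   # (b,l,4,1) real matches
gaz_embeds = gaz_embeds.sum(-2) / gaz_num                           # (b,l,4,ge) masked mean
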
Example #4
class JointModel(nn.Module):
    def __init__(self, data):
        super(JointModel, self).__init__()
        self.data = data
        self.use_crf = data.use_crf
        logger.info("build network...")
        logger.info("word feature extractor: %s" % data.word_feature_extractor)
        logger.info("use_cuda: %s" % data.HP_gpu)
        self.gpu = data.HP_gpu
        logger.info("use_crf: %s" % data.use_crf)
        self.average_batch = data.average_batch_loss

        label_size = data.label_alphabet_size
        sentence_size = data.sentence_alphabet_size
        self.word_hidden = JointSequence(data)
        if self.use_crf:
            self.word_crf = CRF(label_size, batch_first=True)
            self.sent_crf = CRF(sentence_size, batch_first=True)
            if self.gpu:
                self.word_crf = self.word_crf.cuda()
                self.sent_crf = self.sent_crf.cuda()

    def neg_log_likelihood_loss(self,
                                word_inputs,
                                word_tensor,
                                word_seq_lengths,
                                batch_label,
                                batch_sent_type,
                                mask,
                                sent_mask,
                                input_label_seq_tensor,
                                input_sent_type_tensor,
                                batch_word_recover,
                                word_perm_idx,
                                need_cat=True,
                                need_embedding=True):
        words_outs, sent_out = self.word_hidden(
            word_inputs, word_tensor, word_seq_lengths, input_label_seq_tensor,
            input_sent_type_tensor, batch_word_recover, word_perm_idx,
            batch_sent_type, need_cat, need_embedding)
        batch_size = words_outs.size(0)
        seq_len = words_outs.size(1)
        if self.use_crf:
            # e_out(batch_size,sentence_length,tag_size)
            words_loss = (-self.word_crf(words_outs, batch_label, mask)) / (
                len(word_seq_lengths) * seq_len)
            words_tag_seq = self.word_crf.decode(words_outs, mask)
            sent_total_loss = -self.sent_crf(
                sent_out, batch_sent_type[batch_word_recover].view(
                    batch_size, 1),
                sent_mask.view(batch_size, 1).byte()) / len(sent_mask)
            sent_tag_seq = self.sent_crf.decode(
                sent_out,
                sent_mask.view(batch_size, 1).byte())
        else:
            loss_function = nn.NLLLoss()
            words_outs = words_outs.view(batch_size * seq_len, -1)
            words_score = F.log_softmax(words_outs, 1)
            words_loss = loss_function(
                words_score,
                batch_label.contiguous().view(batch_size * seq_len))
            _, words_tag_seq = torch.max(words_score, 1)
            words_tag_seq = words_tag_seq.view(batch_size, seq_len)

            sent_out = sent_out.view(batch_size, -1)
            sent_score = F.log_softmax(sent_out, 1)
            sent_total_loss = loss_function(
                sent_score,
                batch_sent_type[batch_word_recover].view(batch_size))
            _, sent_tag_seq = torch.max(sent_score, 1)
        return words_loss, words_tag_seq, sent_total_loss, sent_tag_seq

    def evaluate(self,
                 word_inputs,
                 word_tensor,
                 word_seq_lengths,
                 batch_sent_type,
                 mask,
                 sent_mask,
                 input_label_seq_tensor,
                 input_sent_type_tensor,
                 batch_word_recover,
                 word_perm_idx,
                 need_cat=True,
                 need_embedding=True):
        words_out, sent_out = self.word_hidden(
            word_inputs, word_tensor, word_seq_lengths, input_label_seq_tensor,
            input_sent_type_tensor, batch_word_recover, word_perm_idx,
            batch_sent_type, need_cat, need_embedding)
        batch_size = words_out.size(0)
        seq_len = words_out.size(1)
        if self.use_crf:
            sent_tag_seq = self.sent_crf.decode(
                sent_out,
                sent_mask.view(batch_size, 1).byte())
            # the sentence predictions are already back in the original order, but the
            # word-level outputs are still length-sorted, so re-permute the sentence tags to match
            sent_tag_seq = torch.tensor(sent_tag_seq)[word_perm_idx]
            if self.gpu:
                sent_tag_seq = sent_tag_seq.cpu().data.numpy().tolist()
            else:
                sent_tag_seq = sent_tag_seq.data.numpy().tolist()
            words_tag_seq = self.word_crf.decode(words_out, mask)
        else:
            sent_out = sent_out.view(batch_size, -1)
            _, sent_tag_seq = torch.max(sent_out, 1)
            # the sentence predictions are already back in the original order, but the
            # word-level outputs are still length-sorted, so re-permute the sentence tags to match
            sent_tag_seq = sent_tag_seq[word_perm_idx]

            words_out = words_out.view(batch_size * seq_len, -1)
            _, words_tag_seq = torch.max(words_out, 1)
            words_tag_seq = mask.long() * words_tag_seq.view(
                batch_size, seq_len)
        return words_tag_seq, sent_tag_seq

    def forward(self,
                word_inputs,
                word_tensor,
                word_seq_lengths,
                mask,
                sent_mask,
                input_label_seq_tensor,
                input_sent_type_tensor,
                batch_word_recover,
                word_perm_idx,
                need_cat=True,
                need_embedding=True):
        batch_size = word_tensor.size(0)
        seq_len = word_tensor.size(1)
        lstm_out, hidden, sent_out, label_embs = self.word_hidden.evaluate_sentence(
            word_inputs, word_tensor, word_seq_lengths, input_label_seq_tensor,
            input_sent_type_tensor, batch_word_recover, need_cat,
            need_embedding)
        lstm_out = torch.cat([
            lstm_out, sent_out[word_perm_idx].expand(
                [lstm_out.size(0),
                 lstm_out.size(1),
                 sent_out.size(-1)])
        ], -1)
        words_outs = self.word_hidden.evaluate_word(lstm_out, hidden,
                                                    word_seq_lengths,
                                                    label_embs)
        if self.use_crf:
            sent_tag_seq = self.sent_crf.decode(
                sent_out,
                sent_mask.view(batch_size, 1).byte())
            # the sentence predictions are already back in the original order, but the
            # word-level outputs are still length-sorted, so re-permute the sentence tags to match
            sent_tag_seq = torch.tensor(sent_tag_seq)[word_perm_idx]
            words_tag_seq = self.word_crf.decode(words_outs, mask)
        else:
            sent_out = sent_out.view(batch_size, -1)
            _, sent_tag_seq = torch.max(sent_out, 1)
            # the sentence predictions are already back in the original order, but the
            # word-level outputs are still length-sorted, so re-permute the sentence tags to match
            sent_tag_seq = sent_tag_seq[word_perm_idx]
            words_outs = words_outs.view(batch_size * seq_len, -1)
            _, words_tag_seq = torch.max(words_outs, 1)
            words_tag_seq = mask.long() * words_tag_seq.view(
                batch_size, seq_len)
        return words_tag_seq, sent_tag_seq
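
The comments in evaluate and forward refer to the interplay between word_perm_idx and batch_word_recover. A minimal, self-contained sketch of how such an index pair is typically produced (sentences sorted by length for packing, then restored):

import torch

word_seq_lengths = torch.tensor([3, 7, 5])
sorted_lengths, word_perm_idx = word_seq_lengths.sort(0, descending=True)
# word_perm_idx -> tensor([1, 2, 0]): original positions of the length-sorted sentences
_, batch_word_recover = word_perm_idx.sort(0, descending=False)
# batch_word_recover -> tensor([2, 0, 1]): indexing a length-sorted tensor with it restores
# the original order; indexing an already-restored tensor with word_perm_idx puts it back
# into the sorted order, which is what the code above does for sent_tag_seq.
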
Example #5
File: CNNmodel.py  Project: rtmaww/LR-CNN
class CNNmodel(nn.Module):
    def __init__(self, data):
        super(CNNmodel, self).__init__()
        self.gpu = data.HP_gpu
        self.use_biword = data.use_bigram
        self.use_posi = data.HP_use_posi
        self.hidden_dim = data.HP_hidden_dim
        self.gaz_alphabet = data.gaz_alphabet
        self.gaz_emb_dim = data.gaz_emb_dim
        self.word_emb_dim = data.word_emb_dim
        self.posi_emb_dim = data.posi_emb_dim
        self.biword_emb_dim = data.biword_emb_dim
        self.rethink_iter = data.HP_rethink_iter

        scale = np.sqrt(3.0 / self.gaz_emb_dim)
        data.pretrain_gaz_embedding[0, :] = np.random.uniform(
            -scale, scale, [1, self.gaz_emb_dim])
        self.gaz_embedding = nn.Embedding(data.gaz_alphabet.size(),
                                          self.gaz_emb_dim)
        self.gaz_embedding.weight.data.copy_(
            torch.from_numpy(data.pretrain_gaz_embedding))

        self.word_embedding = nn.Embedding(data.word_alphabet.size(),
                                           self.word_emb_dim)
        self.word_embedding.weight.data.copy_(
            torch.from_numpy(data.pretrain_word_embedding))

        if data.HP_use_posi:
            data.posi_alphabet_size += 1
            self.position_embedding = nn.Embedding.from_pretrained(
                get_sinusoid_encoding_table(data.posi_alphabet_size,
                                            self.posi_emb_dim),
                freeze=True)

        if self.use_biword:
            self.biword_embedding = nn.Embedding(data.biword_alphabet.size(),
                                                 self.biword_emb_dim)
            self.biword_embedding.weight.data.copy_(
                torch.from_numpy(data.pretrain_biword_embedding))

        self.drop = nn.Dropout(p=data.HP_dropout)
        self.num_layer = data.HP_num_layer

        input_dim = self.word_emb_dim
        if self.use_biword:
            input_dim += self.biword_emb_dim
        if self.use_posi:
            input_dim += self.posi_emb_dim

        self.cnn_layer0 = nn.Conv1d(input_dim,
                                    self.hidden_dim,
                                    kernel_size=1,
                                    padding=0)
        self.cnn_layers = [
            nn.Conv1d(self.hidden_dim,
                      self.hidden_dim,
                      kernel_size=2,
                      padding=0) for i in range(self.num_layer - 1)
        ]
        self.cnn_layers_back = [
            nn.Conv1d(self.hidden_dim,
                      self.hidden_dim,
                      kernel_size=2,
                      padding=0) for i in range(self.num_layer - 1)
        ]
        self.res_cnn_layers = [
            nn.Conv1d(self.hidden_dim,
                      self.hidden_dim,
                      kernel_size=i + 2,
                      padding=0) for i in range(1, self.num_layer - 1)
        ]
        self.res_cnn_layers_back = [
            nn.Conv1d(self.hidden_dim,
                      self.hidden_dim,
                      kernel_size=i + 2,
                      padding=0) for i in range(1, self.num_layer - 1)
        ]

        self.layer_gate = LayerGate(self.hidden_dim, self.gaz_emb_dim)
        self.global_gate = GlobalGate(self.hidden_dim)
        self.exper2gate = nn.Linear(self.hidden_dim, self.hidden_dim * 4)
        self.multiscale_layer = MultiscaleAttention(self.num_layer,
                                                    data.HP_dropout)

        self.hidden2tag = nn.Linear(self.hidden_dim,
                                    data.label_alphabet_size + 2)
        self.crf = CRF(data.label_alphabet_size, self.gpu)

        if self.gpu:
            self.gaz_embedding = self.gaz_embedding.cuda()
            self.word_embedding = self.word_embedding.cuda()
            if self.use_posi:
                self.position_embedding = self.position_embedding.cuda()
            if self.use_biword:
                self.biword_embedding = self.biword_embedding.cuda()
            self.cnn_layer0 = self.cnn_layer0.cuda()
            self.multiscale_layer = self.multiscale_layer.cuda()
            self.hidden2tag = self.hidden2tag.cuda()
            self.layer_gate = self.layer_gate.cuda()
            self.global_gate = self.global_gate.cuda()
            self.crf = self.crf.cuda()
            for i in range(self.num_layer - 1):
                self.cnn_layers[i] = self.cnn_layers[i].cuda()
                self.cnn_layers_back[i] = self.cnn_layers_back[i].cuda()
                if i >= 1:
                    self.res_cnn_layers[i - 1] = self.res_cnn_layers[i -
                                                                     1].cuda()
                    self.res_cnn_layers_back[i - 1] = self.res_cnn_layers_back[
                        i - 1].cuda()

    def get_tags(self, gaz_list, word_inputs, biword_inputs, layer_gaz,
                 gaz_mask_input, mask):

        batch_size = word_inputs.size(0)
        seq_len = word_inputs.size(1)

        word_embs = self.word_embedding(word_inputs)

        if self.use_biword:
            biword_embs = self.biword_embedding(biword_inputs)
            word_embs = torch.cat([word_embs, biword_embs], dim=2)

        if self.use_posi:
            posi_inputs = torch.zeros(batch_size, seq_len).long()

            posi_inputs[:, :] = torch.LongTensor(
                [i + 1 for i in range(seq_len)])
            if self.gpu:
                posi_inputs = posi_inputs.cuda()
            position_embs = self.position_embedding(posi_inputs)
            word_embs = torch.cat([word_embs, position_embs], dim=2)

        word_inputs_d = self.drop(word_embs)
        word_inputs_d = word_inputs_d.transpose(2, 1).contiguous()

        X_pre = self.cnn_layer0(
            word_inputs_d)  #(batch_size,hidden_size,seq_len)
        X_pre = torch.tanh(X_pre)

        X_trans = X_pre.transpose(2, 1).contiguous()

        global_matrix0 = self.global_gate(X_trans)  # G0

        X_list = X_trans.unsqueeze(
            2)  #(batch_size,seq_len,num_layer,hidden_size)

        padding = torch.zeros(batch_size, self.hidden_dim, 1)
        if self.gpu:
            padding = padding.cuda()

        feed_back = None
        for iteration in range(self.rethink_iter):

            global_matrix = global_matrix0
            X_pre = self.drop(X_pre)

            X_pre_padding = torch.cat(
                [X_pre, padding], dim=2)  #(batch_size,hidden_size,seq_len+1)

            X_pre_padding_back = torch.cat([padding, X_pre], dim=2)

            for layer in range(self.num_layer - 1):

                X = self.cnn_layers[layer](
                    X_pre_padding)  # X: (batch_size,hidden_size,seq_len)
                X = torch.tanh(X)

                X_back = self.cnn_layers_back[layer](
                    X_pre_padding_back)  # X: (batch_size,hidden_size,seq_len)
                X_back = torch.tanh(X_back)

                if layer > 0:
                    windowpad = torch.cat([padding for i in range(layer)],
                                          dim=2)
                    X_pre_padding_w = torch.cat([X_pre, windowpad, padding],
                                                dim=2)
                    X_res = self.res_cnn_layers[layer - 1](X_pre_padding_w)
                    X_res = torch.tanh(X_res)

                    X_pre_padding_w_back = torch.cat(
                        [padding, windowpad, X_pre], dim=2)
                    X_res_back = self.res_cnn_layers_back[layer - 1](
                        X_pre_padding_w_back)
                    X_res_back = torch.tanh(X_res_back)

                layer_gaz_back = torch.zeros(batch_size, seq_len).long()

                if seq_len > layer + 1:
                    layer_gaz_back[:, layer + 1:] = layer_gaz[:, :seq_len -
                                                              layer - 1, layer]

                if self.gpu:
                    layer_gaz_back = layer_gaz_back.cuda()

                gazs_embeds = self.gaz_embedding(layer_gaz[:, :, layer])
                gazs_embeds_back = self.gaz_embedding(layer_gaz_back)

                mask_gaz = (mask == 0).unsqueeze(-1).repeat(
                    1, 1, self.gaz_emb_dim)
                gazs_embeds = gazs_embeds.masked_fill(mask_gaz, 0)
                gazs_embeds_back = gazs_embeds_back.masked_fill(mask_gaz, 0)

                gazs_embeds = self.drop(gazs_embeds)
                gazs_embeds_back = self.drop(gazs_embeds_back)

                if layer > 0:  #res
                    X_input = torch.cat([X, X_back, X_res, X_res_back],
                                        dim=-1).transpose(
                                            2, 1).contiguous()  #(b,4l,h)
                    X, X_back, X_res, X_res_back = self.layer_gate(
                        X_input,
                        gazs_embeds,
                        gazs_embeds_back,
                        global_matrix,
                        exper_input=feed_back,
                        gaz_mask=None)
                    X = X + X_back + X_res + X_res_back
                else:
                    X_input = torch.cat([X, X_back, X, X_back],
                                        dim=-1).transpose(
                                            2, 1).contiguous()  #(b,4l,h)
                    X, X_back, _, _ = self.layer_gate(X_input,
                                                      gazs_embeds,
                                                      gazs_embeds_back,
                                                      global_matrix,
                                                      exper_input=feed_back,
                                                      gaz_mask=None)
                    X = X + X_back

                global_matrix = self.global_gate(X, global_matrix)
                if iteration == self.rethink_iter - 1:
                    X_list = torch.cat([X_list, X.unsqueeze(2)], dim=2)
                if layer == self.num_layer - 2:
                    feed_back = X

                X = X.transpose(2, 1).contiguous()
                X_d = self.drop(X)

                X_pre_padding = torch.cat([X_d, padding], dim=2)  #padding

                padding_back = torch.cat(
                    [padding for _ in range(min(layer + 2, seq_len + 1))],
                    dim=2)
                if seq_len > layer + 1:
                    X_pre_padding_back = torch.cat(
                        [padding_back, X_d[:, :, :seq_len - layer - 1]],
                        dim=2)  #(b,h,seqlen+1)
                else:
                    X_pre_padding_back = padding_back

        X_attention = self.multiscale_layer(X_list)
        tags = self.hidden2tag(X_attention)  #(b,l,t)

        return tags

    def neg_log_likelihood_loss(self, gaz_list, word_inputs, biword_inputs,
                                word_seq_lengths, layer_gaz, gaz_mask, mask,
                                batch_label):

        tags = self.get_tags(gaz_list, word_inputs, biword_inputs, layer_gaz,
                             gaz_mask, mask)

        total_loss = self.crf.neg_log_likelihood_loss(tags, mask, batch_label)
        scores, tag_seq = self.crf._viterbi_decode(tags, mask)

        return total_loss, tag_seq

    def forward(self, gaz_list, word_inputs, biword_inputs, word_seq_lengths,
                layer_gaz, gaz_mask, mask):

        tags = self.get_tags(gaz_list, word_inputs, biword_inputs, layer_gaz,
                             gaz_mask, mask)

        scores, tag_seq = self.crf._viterbi_decode(tags, mask)

        return tag_seq
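
One pitfall worth noting in CNNmodel.__init__ above: cnn_layers, cnn_layers_back, and the residual layer lists are plain Python lists, so their parameters are not registered with the module (they are moved to the GPU by hand, but model.parameters() and state_dict() will not include them). A small sketch of the difference, using nn.ModuleList:

import torch.nn as nn

class WithList(nn.Module):
    def __init__(self):
        super().__init__()
        # plain list: the Conv1d parameters are invisible to the optimizer
        self.layers = [nn.Conv1d(8, 8, kernel_size=2) for _ in range(3)]

class WithModuleList(nn.Module):
    def __init__(self):
        super().__init__()
        # ModuleList registers each layer; .cuda()/.to()/state_dict() work recursively
        self.layers = nn.ModuleList(nn.Conv1d(8, 8, kernel_size=2) for _ in range(3))

print(sum(p.numel() for p in WithList().parameters()))        # 0
print(sum(p.numel() for p in WithModuleList().parameters()))  # 408 = 3 * (8*8*2 + 8)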