Example #1
    def get_embedding_from_list(self, item_list, embedding_names, initial_embed):
        # Look up each requested embedding (PHOC, fastText, GloVe, BERT, POS, entity)
        # and collect the results for concatenation
        emb_list = []
        if 'phoc' in embedding_names: 
            phoc_emb = self.phoc_embed(item_list['phoc'])
            if 'dropout_emb' in self.opt:
                emb_list.append(dropout(phoc_emb, p=self.opt['dropout_emb'], training=self.drop_emb))
            else:
                emb_list.append(phoc_emb)
        if 'fasttext' in embedding_names:
            fast_emb = self.fast_embed(item_list['fasttext'])
            if 'PRE_ALIGN_befor_rnn' in self.opt:
                item_list['fasttext_emb'] = fast_emb
            if 'dropout_emb' in self.opt:
                emb_list.append(dropout(fast_emb, p=self.opt['dropout_emb'], training=self.drop_emb))
            else:
                emb_list.append(fast_emb)

        if 'glove' in embedding_names:
            glove_emb = self.glove_embed(item_list['glove'])
            if 'PRE_ALIGN_befor_rnn' in self.opt:
                item_list['glove_emb'] = glove_emb
            if 'dropout_emb' in self.opt:
                emb_list.append(dropout(glove_emb, p=self.opt['dropout_emb'], training=self.drop_emb))
            else:
                emb_list.append(glove_emb)
        for k in ['bert', 'bert_only']:
            if k in embedding_names:
                if 'ModelParallel' in self.opt:
                    # Under model parallelism, BERT runs on its own device
                    bert_cuda = self.bert_cuda
                    main_cuda = self.main_cuda
                    if k == 'bert':
                        if 'fasttext' == initial_embed:
                            bert_output = self.Bert(item_list['bert'], item_list['bert_mask'], item_list['bert_offsets'], item_list['fasttext_mask'].to(bert_cuda), device=main_cuda)
                        else:
                            bert_output = self.Bert(item_list['bert'], item_list['bert_mask'], item_list['bert_offsets'], item_list['glove_mask'].to(bert_cuda), device=main_cuda)
                else:
                    if k == 'bert':
                        if 'fasttext' == initial_embed:
                            bert_output = self.Bert(item_list['bert'], item_list['bert_mask'],
                                                    item_list['bert_offsets'], item_list['fasttext_mask'])
                        else:
                            bert_output = self.Bert(item_list['bert'], item_list['bert_mask'], item_list['bert_offsets'], item_list['glove_mask'])
                if 'BERT_LINEAR_COMBINE' in self.opt:
                    # Learned weighted sum over the per-layer BERT outputs
                    bert_output = self.linear_sum(bert_output, self.alphaBERT, self.gammaBERT)
                emb_list.append(bert_output)
        if 'pos' in embedding_names:
            emb_list.append(
                self.pos_embedding(item_list['pos'])
            )
        if 'ent' in embedding_names:
            emb_list.append(
                self.ent_embedding(item_list['ent'])
            )
        res = torch.cat(emb_list, dim=-1)  # concatenate all embeddings along the last dimension
        return res
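
Stripped of the option flags, the method above is a set of embedding lookups whose results are concatenated along the last (feature) dimension. Below is a minimal, self-contained sketch of that pattern with only two embedding families and made-up table sizes; it is an illustration, not the actual class from this codebase.

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiEmbedding(nn.Module):
    # Toy version of the lookup-and-concatenate pattern shown above
    def __init__(self, vocab_size=100, word_dim=300, pos_size=20, pos_dim=12, p_drop=0.3):
        super().__init__()
        self.glove_embed = nn.Embedding(vocab_size, word_dim)
        self.pos_embedding = nn.Embedding(pos_size, pos_dim)
        self.p_drop = p_drop

    def forward(self, item_list):
        emb_list = []
        glove_emb = self.glove_embed(item_list['glove'])        # batch x len x word_dim
        emb_list.append(F.dropout(glove_emb, p=self.p_drop, training=self.training))
        emb_list.append(self.pos_embedding(item_list['pos']))   # batch x len x pos_dim
        return torch.cat(emb_list, dim=-1)                      # batch x len x (word_dim + pos_dim)

batch = {'glove': torch.randint(0, 100, (2, 7)), 'pos': torch.randint(0, 20, (2, 7))}
print(MultiEmbedding()(batch).shape)  # torch.Size([2, 7, 312])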
Example #2
    def linear_sum(self, output, alpha, gamma):
        # Weighted sum of per-layer outputs: softmax-normalized alpha weights, scaled by gamma
        alpha_softmax = F.softmax(alpha, dim=0)
        for i in range(len(output)):
            t = output[i] * alpha_softmax[i] * gamma
            if i == 0:
                res = t
            else:
                res += t

        res = dropout(res, p=self.opt['dropout_emb'], training=self.drop_emb)
        return res
Example #3
    def linear_sum(self, output, alpha, gamma):
        alpha_softmax = F.softmax(alpha, dim=0)  # normalize the per-layer weights alpha
        for i in range(len(output)):
            # layer i is weighted by alpha_softmax[i] * gamma
            t = output[i] * alpha_softmax[i] * gamma
            if i == 0:
                res = t
            else:
                res += t
        res = dropout(x=res, p=self.opt['dropout_emb'],
                      training=self.drop_emb)  # apply dropout before returning
        return res
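
Both linear_sum variants implement a scalar mix over per-layer encoder outputs: softmax-normalized weights alpha, a global scale gamma, and dropout on the result. The following self-contained sketch shows the same computation under those assumptions; the layer count, dimensions, and dropout rate are illustrative only.

import torch
import torch.nn.functional as F

def scalar_mix(layer_outputs, alpha, gamma, p_drop=0.4, training=True):
    # layer_outputs: list of [batch, seq_len, dim] tensors, one per encoder layer
    weights = F.softmax(alpha, dim=0)                            # normalize the per-layer weights
    mixed = sum(w * layer for w, layer in zip(weights, layer_outputs)) * gamma
    return F.dropout(mixed, p=p_drop, training=training)

layers = [torch.randn(2, 5, 8) for _ in range(4)]                # e.g. outputs of 4 BERT layers
alpha = torch.zeros(4, requires_grad=True)                       # learnable mixing weights
gamma = torch.tensor(1.0, requires_grad=True)                    # learnable global scale
print(scalar_mix(layers, alpha, gamma).shape)                    # torch.Size([2, 5, 8])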
Example #4
    def forward(self, x, x_single_mask, x_char, x_char_mask, x_features, x_pos,
                x_ent, x_bert, x_bert_mask, x_bert_offsets, q, q_mask, q_char,
                q_char_mask, q_bert, q_bert_mask, q_bert_offsets, context_len):
        batch_size = q.shape[0]
        x_mask = x_single_mask.expand(batch_size, -1)
        x_word_embed = self.vocab_embed(x).expand(
            batch_size, -1, -1)  # batch x x_len x vocab_dim
        ques_word_embed = self.vocab_embed(q)  # batch x q_len x vocab_dim

        x_input_list = [
            dropout(x_word_embed,
                    p=self.opt['dropout_emb'],
                    training=self.drop_emb)
        ]  # batch x x_len x vocab_dim
        ques_input_list = [
            dropout(ques_word_embed,
                    p=self.opt['dropout_emb'],
                    training=self.drop_emb)
        ]  # batch x q_len x vocab_dim

        # contextualized embedding
        x_cemb = ques_cemb = None
        if 'BERT' in self.opt:
            x_cemb = ques_cemb = None

            if 'BERT_LINEAR_COMBINE' in self.opt:
                x_bert_output = self.Bert(x_bert, x_bert_mask, x_bert_offsets,
                                          x_single_mask)
                x_cemb_mid = self.linear_sum(x_bert_output, self.alphaBERT,
                                             self.gammaBERT)
                ques_bert_output = self.Bert(q_bert, q_bert_mask,
                                             q_bert_offsets, q_mask)
                ques_cemb_mid = self.linear_sum(ques_bert_output,
                                                self.alphaBERT, self.gammaBERT)
                x_cemb_mid = x_cemb_mid.expand(batch_size, -1, -1)
            else:
                x_cemb_mid = self.Bert(x_bert, x_bert_mask, x_bert_offsets,
                                       x_single_mask)
                x_cemb_mid = x_cemb_mid.expand(batch_size, -1, -1)
                ques_cemb_mid = self.Bert(q_bert, q_bert_mask, q_bert_offsets,
                                          q_mask)

            x_input_list.append(x_cemb_mid)
            ques_input_list.append(ques_cemb_mid)

        if 'CHAR_CNN' in self.opt:
            x_char_final = self.character_cnn(x_char, x_char_mask)
            x_char_final = x_char_final.expand(batch_size, -1, -1)
            ques_char_final = self.character_cnn(q_char, q_char_mask)
            x_input_list.append(x_char_final)
            ques_input_list.append(ques_char_final)

        x_prealign = self.pre_align(x_word_embed, ques_word_embed, q_mask)
        x_input_list.append(
            x_prealign)  # batch x x_len x (vocab_dim + cdim + vocab_dim)

        x_pos_emb = self.pos_embedding(x_pos).expand(
            batch_size, -1, -1)  # batch x x_len x pos_dim
        x_ent_emb = self.ent_embedding(x_ent).expand(
            batch_size, -1, -1)  # batch x x_len x ent_dim
        x_input_list.append(x_pos_emb)
        x_input_list.append(x_ent_emb)
        x_input_list.append(
            x_features
        )  # batch x x_len x (vocab_dim + cdim + vocab_dim + pos_dim + ent_dim + feature_dim)

        x_input = torch.cat(
            x_input_list, 2
        )  # batch x x_len x (vocab_dim + cdim + vocab_dim + pos_dim + ent_dim + feature_dim)
        ques_input = torch.cat(ques_input_list,
                               2)  # batch x q_len x (vocab_dim + cdim)

        # Multi-layer RNN
        _, x_rnn_layers = self.context_rnn(
            x_input, x_mask, return_list=True, x_additional=x_cemb
        )  # layer x batch x x_len x context_rnn_output_size
        _, ques_rnn_layers = self.ques_rnn(
            ques_input, q_mask, return_list=True, x_additional=ques_cemb
        )  # layer x batch x q_len x ques_rnn_output_size

        # rnn with question only
        ques_highlvl = self.high_lvl_ques_rnn(
            torch.cat(ques_rnn_layers, 2),
            q_mask)  # batch x q_len x high_lvl_ques_rnn_output_size
        ques_rnn_layers.append(ques_highlvl)  # (layer + 1) layers

        # deep multilevel inter-attention
        if x_cemb is None:
            x_long = x_word_embed
            ques_long = ques_word_embed
        else:
            x_long = torch.cat([x_word_embed, x_cemb],
                               2)  # batch x x_len x (vocab_dim + cdim)
            ques_long = torch.cat([ques_word_embed, ques_cemb],
                                  2)  # batch x q_len x (vocab_dim + cdim)

        x_rnn_after_inter_attn, x_inter_attn = self.deep_attn(
            [x_long],
            x_rnn_layers, [ques_long],
            ques_rnn_layers,
            x_mask,
            q_mask,
            return_bef_rnn=True)
        # x_rnn_after_inter_attn: batch x x_len x deep_attn_output_size
        # x_inter_attn: batch x x_len x deep_attn_input_size

        # deep self attention
        if x_cemb is None:
            x_self_attn_input = torch.cat(
                [x_rnn_after_inter_attn, x_inter_attn, x_word_embed], 2)
        else:
            x_self_attn_input = torch.cat(
                [x_rnn_after_inter_attn, x_inter_attn, x_cemb, x_word_embed],
                2)
            # batch x x_len x (deep_attn_output_size + deep_attn_input_size + cdim + vocab_dim)

        x_self_attn_output = self.highlvl_self_att(x_self_attn_input,
                                                   x_self_attn_input,
                                                   x_mask,
                                                   x3=x_rnn_after_inter_attn,
                                                   drop_diagonal=True)
        # batch x x_len x deep_attn_output_size

        x_highlvl_output = self.high_lvl_context_rnn(
            torch.cat([x_rnn_after_inter_attn, x_self_attn_output], 2), x_mask)
        # batch x x_len x high_lvl_context_rnn.output_size
        x_final = x_highlvl_output

        # question self attention
        ques_final = self.ques_self_attn(
            ques_highlvl, ques_highlvl, q_mask, x3=None, drop_diagonal=True
        )  # batch x q_len x high_lvl_ques_rnn_output_size

        # merge questions
        q_merge_weights = self.ques_merger(ques_final, q_mask)
        ques_merged = weighted_avg(ques_final,
                                   q_merge_weights)  # batch x ques_final_size

        # predict scores
        score_s, score_e, score_no, score_yes, score_noanswer = self.get_answer(
            x_final, ques_merged, x_mask)
        return score_s, score_e, score_no, score_yes, score_noanswer
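
pre_align in the forward pass computes, for each passage word, an attention-weighted average of the question word embeddings (a DrQA-style aligned question embedding). The actual layer is not shown in these examples; the sketch below is only a common formulation of that idea, with its own hypothetical mask convention (1 for real tokens, 0 for padding).

import torch
import torch.nn.functional as F

def pre_align_sketch(x_word_embed, ques_word_embed, q_mask):
    # x_word_embed: [batch, x_len, dim], ques_word_embed: [batch, q_len, dim]
    # q_mask: [batch, q_len] with 1 for real tokens and 0 for padding (assumed convention)
    scores = x_word_embed.bmm(ques_word_embed.transpose(1, 2))   # [batch, x_len, q_len]
    scores = scores.masked_fill(q_mask.unsqueeze(1) == 0, float('-inf'))
    attn = F.softmax(scores, dim=-1)                             # attention over question words
    return attn.bmm(ques_word_embed)                             # [batch, x_len, dim]

x = torch.randn(2, 9, 16)                                        # passage word embeddings
q = torch.randn(2, 4, 16)                                        # question word embeddings
q_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
print(pre_align_sketch(x, q, q_mask).shape)                      # torch.Size([2, 9, 16])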
Example #5
    def forward(self, x, x_single_mask, x_char, x_char_mask, x_features, x_pos,
                x_ent, x_bert, x_bert_mask, x_bert_offsets, q, q_mask, q_char,
                q_char_mask, q_bert, q_bert_mask, q_bert_offsets, context_len):
        """
        forward() takes a batch produced by BatchGen() as input and computes the final
        scores through the encoding, interaction, and output layers.
        :param x: [1, x_len] (word_ids)
        :param x_single_mask: [1, x_len]
        :param x_char: [1, x_len, char_len] (char_ids)
        :param x_char_mask: [1, x_len, char_len]
        :param x_features: [batch_size, x_len, feature_len] (feature_len is 5 if answer_span_in_context_feature is set, otherwise 4)
        :param x_pos: [1, x_len] (POS id)
        :param x_ent: [1, x_len] (ENT id)
        :param x_bert: [1, x_bert_token_len]
        :param x_bert_mask: [1, x_bert_token_len]
        :param x_bert_offsets: [1, x_len, 2]
        :param q: [batch, q_len] (word_ids)
        :param q_mask: [batch, q_len]
        :param q_char: [batch, q_len, char_len] (char ids)
        :param q_char_mask: [batch, q_len, char_len]
        :param q_bert: [1, q_bert_token_len]
        :param q_bert_mask: [1, q_bert_token_len]
        :param q_bert_offsets: [1, q_len, 2]
        :param context_len: number of words in context (only one per batch)
        :return:
            score_s: [batch, context_len]
            score_e: [batch, context_len]
            score_no: [batch, 1]
            score_yes: [batch, 1]
            score_noanswer: [batch, 1]
        """
        batch_size = q.shape[0]
        # All QA pairs in a batch share one passage, so x_single_mask has a single row;
        # expand it to batch_size rows to align with the question tensors
        x_mask = x_single_mask.expand(batch_size, -1)
        # Passage word embeddings, likewise expanded to batch_size rows
        x_word_embed = self.vocab_embed(x).expand(
            batch_size, -1, -1)  # [batch, x_len, vocab_dim]
        # Question word embeddings
        ques_word_embed = self.vocab_embed(q)  # [batch, q_len, vocab_dim]
        # Passage word history (features to be concatenated)
        x_input_list = [
            dropout(x=x_word_embed,
                    p=self.opt['dropout_emb'],
                    training=self.drop_emb)
        ]  # [batch, x_len, vocab_dim]
        # Question word history
        ques_input_list = [
            dropout(x=ques_word_embed,
                    p=self.opt['dropout_emb'],
                    training=self.drop_emb)
        ]  # [batch, q_len, vocab_dim]
        # Contextualized encoding layer
        x_cemb = ques_cemb = None
        if 'BERT' in self.opt:
            x_cemb = ques_cemb = None

            if 'BERT_LINEAR_COMBINE' in self.opt:
                # Per-layer BERT outputs for the passage words
                x_bert_output = self.Bert(x_bert, x_bert_mask, x_bert_offsets,
                                          x_single_mask)
                # Weighted sum over the BERT layers
                x_cemb_mid = self.linear_sum(x_bert_output, self.alphaBERT,
                                             self.gammaBERT)
                # Per-layer BERT outputs for the question words
                ques_bert_output = self.Bert(q_bert, q_bert_mask,
                                             q_bert_offsets, q_mask)
                # Weighted sum over the BERT layers
                ques_cemb_mid = self.linear_sum(ques_bert_output,
                                                self.alphaBERT, self.gammaBERT)
                x_cemb_mid = x_cemb_mid.expand(batch_size, -1, -1)
            else:
                # Without the linear combination, use the BERT output directly
                x_cemb_mid = self.Bert(x_bert, x_bert_mask, x_bert_offsets,
                                       x_single_mask)
                x_cemb_mid = x_cemb_mid.expand(batch_size, -1, -1)
                ques_cemb_mid = self.Bert(q_bert, q_bert_mask, q_bert_offsets,
                                          q_mask)

            # Append the contextualized embeddings to the word history
            x_input_list.append(x_cemb_mid)
            ques_input_list.append(ques_cemb_mid)

        if 'CHAR_CNN' in self.opt:
            x_char_final = self.character_cnn(x_char, x_char_mask)
            x_char_final = x_char_final.expand(batch_size, -1, -1)
            ques_char_final = self.character_cnn(q_char, q_char_mask)
            x_input_list.append(x_char_final)
            ques_input_list.append(ques_char_final)

        # Word-level attention layer
        x_prealign = self.pre_align(x_word_embed, ques_word_embed, q_mask)
        x_input_list.append(
            x_prealign)  # [batch, x_len, vocab_dim + cdim + vocab_dim]
        # POS embedding
        x_pos_emb = self.pos_embedding(x_pos).expand(
            batch_size, -1, -1)  # [batch, x_len, pos_dim]
        # Named-entity embedding
        x_ent_emb = self.ent_embedding(x_ent).expand(
            batch_size, -1, -1)  # [batch, x_len, ent_dim]
        x_input_list.append(x_pos_emb)
        x_input_list.append(x_ent_emb)
        # Append the term-frequency and exact-match features of the passage words
        x_input_list.append(
            x_features
        )  # [batch_size, x_len, vocab_dim + cdim + vocab_dim + pos_dim + ent_dim + feature_dim]
        # Concatenate the passage word-history vectors
        x_input = torch.cat(
            x_input_list, 2
        )  # [batch_size, x_len, vocab_dim + cdim + vocab_dim + pos_dim + ent_dim + feature_dim]
        # Concatenate the question word-history vectors
        ques_input = torch.cat(ques_input_list,
                               2)  # [batch_size, q_len, vocab_dim + cdim]
        # Multi-layer RNN: per-layer outputs for the passage and the question
        _, x_rnn_layers = self.context_rnn(
            x_input, x_mask, return_list=True, x_additional=x_cemb
        )  # [layer, batch, x_len, context_rnn_output_size]
        _, ques_rnn_layers = self.ques_rnn(
            ques_input, q_mask, return_list=True, x_additional=ques_cemb
        )  # [layer, batch, q_len, ques_rnn_output_size]
        # Question understanding layer
        ques_highlvl = self.high_lvl_ques_rnn(
            torch.cat(ques_rnn_layers, 2),
            q_mask)  # [batch, q_len, high_lvl_ques_rnn_output_size]
        ques_rnn_layers.append(ques_highlvl)  # (layer + 1) layers

        # Deep multilevel inter-attention: inputs to the fully-aware cross-attention layer
        if x_cemb is None:
            x_long = x_word_embed
            ques_long = ques_word_embed
        else:
            x_long = torch.cat([x_word_embed, x_cemb],
                               2)  # [batch, x_len, vocab_dim + cdim]
            ques_long = torch.cat([ques_word_embed, ques_cemb],
                                  2)  # [batch, q_len, vocab_dim + cdim]
        # Passage words through the fully-aware inter-attention layer; x_rnn_after_inter_attn: [batch, x_len, deep_attn_output_size], x_inter_attn: [batch, x_len, deep_attn_input_size]
        x_rnn_after_inter_attn, x_inter_attn = self.deep_attn(
            [x_long],
            x_rnn_layers, [ques_long],
            ques_rnn_layers,
            x_mask,
            q_mask,
            return_bef_rnn=True)

        # Deep self-attention; its input x_self_attn_input: [batch, x_len, deep_attn_output_size + deep_attn_input_size + cdim + vocab_dim]
        if x_cemb is None:
            x_self_attn_input = torch.cat(
                [x_rnn_after_inter_attn, x_inter_attn, x_word_embed], 2)
        else:
            x_self_attn_input = torch.cat(
                [x_rnn_after_inter_attn, x_inter_attn, x_cemb, x_word_embed],
                2)
        # Passage through the fully-aware self-attention layer
        x_self_attn_output = self.highlvl_self_attn(
            x_self_attn_input,
            x_self_attn_input,
            x_mask,
            x3=x_rnn_after_inter_attn,
            drop_diagonal=True)  # [batch, x_len, deep_attn_output_size]

        # Passage words through the high-level RNN layer
        x_highlvl_output = self.high_lvl_context_rnn(
            torch.cat([x_rnn_after_inter_attn, x_self_attn_output], 2), x_mask)

        # Final passage word representation x_final
        x_final = x_highlvl_output  # [batch, x_len, high_lvl_context_rnn_output_size]

        # Self-attention over the question words
        ques_final = self.ques_self_attn(
            ques_highlvl, ques_highlvl, q_mask, x3=None, drop_diagonal=True
        )  # [batch, q_len, high_lvl_ques_rnn_output_size]

        # Merge the questions: obtain a single vector representation per question
        q_merge_weights = self.ques_merger(ques_final, q_mask)
        ques_merged = weighted_avg(
            ques_final, q_merge_weights
        )  # [batch, ques_final_size], the weighted sum of ques_final under q_merge_weights

        # Scores for the answer starting/ending at each passage position, plus the three
        # special answers "yes" / "no" / "no answer"
        score_s, score_e, score_no, score_yes, score_noanswer = self.get_answer(
            x_final, ques_merged, x_mask)

        return score_s, score_e, score_no, score_yes, score_noanswer
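
The question-merging step relies on weighted_avg, an attention-weighted pooling of the question token vectors into one vector per question. The helper itself is not shown in these examples; below is a minimal sketch of the standard formulation, assuming the weights sum to one over the token axis.

import torch

def weighted_avg(x, weights):
    # x: [batch, seq_len, dim]; weights: [batch, seq_len], summing to 1 over seq_len
    return weights.unsqueeze(1).bmm(x).squeeze(1)           # [batch, dim]

ques_final = torch.randn(3, 6, 10)                          # encodings of 6 question tokens
q_merge_weights = torch.softmax(torch.randn(3, 6), dim=-1)  # e.g. a scorer followed by masked softmax
ques_merged = weighted_avg(ques_final, q_merge_weights)
print(ques_merged.shape)                                    # torch.Size([3, 10])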