Example #1
    def forward(self, x):
        mask = compute_mask(x)

        tmp_emb = self.embedding_layer.forward(x)
        out_emb = tmp_emb.transpose(0, 1)

        return out_emb, mask
Example #2
    def collect_fun(self, batch):
        docs = []
        labels = []

        for ele in batch:
            docs.append(ele[0])
            labels.append(ele[1])

        docs = torch.stack(docs, dim=0)
        labels = torch.stack(labels, dim=0)

        # compress on word level
        docs, _ = del_zeros_right(docs)
        docs_mask = compute_mask(docs, padding_idx=Vocabulary.PAD_IDX)

        # compress on sentence level
        if self.hierarchical:
            _, sent_right_idx = del_zeros_right(docs_mask.sum(-1))
            docs = docs[:, :sent_right_idx, :]
            docs_mask = docs_mask[:, :sent_right_idx, :]

        # logger.info('tar_d: {}, {}'.format(docs.dtype, docs.shape))
        # logger.info('label: {}, {}'.format(labels.dtype, labels.shape))

        return docs, docs_mask, labels
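
The collate functions in this listing rely on a del_zeros_right helper that is not shown. As a point of reference, here is a minimal sketch under the assumption that it simply trims trailing all-zero columns from the last dimension and also returns the cut-off index used for slicing above (the real implementation may differ):

import torch

def del_zeros_right(tensor):
    # Keep columns up to the rightmost one that still contains a non-zero
    # entry; everything to the right of it is pure padding.
    non_zero_cols = (tensor != 0).reshape(-1, tensor.size(-1)).any(dim=0)
    right_idx = int(non_zero_cols.nonzero().max().item()) + 1 if non_zero_cols.any() else 1
    return tensor[..., :right_idx], right_idx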
Example #3
    def forward(self, x):
        # mask is a matrix with the same shape as x: positions in x that hold a
        # word are set to 1, padding positions are set to 0
        mask = compute_mask(x, PreprocessData.padding_idx)
        tmp_emb = self.embedding_layer.forward(x)
        # reshape the embedded tensor into (sequence_len, batch_size, embedding)
        out_emb = tmp_emb
        return out_emb, mask
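
The comment above describes what compute_mask does. A minimal sketch matching that description (assuming a default padding index of 0, as Example #4 suggests; the project's actual helper may differ):

import torch

def compute_mask(x, padding_idx=0):
    # 1.0 wherever x holds a real token id, 0.0 at padding positions
    return (x != padding_idx).float()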
Example #4
    def forward(self, x):
        batch_size, seq_len, word_len = x.shape
        x = x.view(-1, word_len)

        mask = compute_mask(x, 0)  # char-level padding idx is zero
        x_emb = self.embedding_layer.forward(x)
        x_emb = x_emb.view(batch_size, seq_len, word_len, -1)
        mask = mask.view(batch_size, seq_len, word_len)

        return x_emb, mask
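
The flatten/restore pattern above can be checked on toy shapes; the sizes below are made up purely for illustration:

import torch

batch_size, seq_len, word_len, emb_dim = 2, 3, 4, 5
x = torch.randint(1, 10, (batch_size, seq_len, word_len))
x[..., 2:] = 0                                      # right-pad each word with zeros
flat = x.view(-1, word_len)                         # (batch_size * seq_len, word_len)
mask = (flat != 0).float()                          # what compute_mask(flat, 0) would return
emb = torch.randn(flat.size(0), word_len, emb_dim)  # stand-in for the embedding output
emb = emb.view(batch_size, seq_len, word_len, -1)   # (batch_size, seq_len, word_len, emb_dim)
mask = mask.view(batch_size, seq_len, word_len)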
Example #5
    def forward(self, tar_d, cand_ds):
        tar_d, _ = del_zeros_right(tar_d)
        cand_ds, _ = del_zeros_right(cand_ds)

        if self.doc_hierarchical:
            _, sent_right_idx = del_zeros_right(tar_d.sum(-1))
            tar_d = tar_d[:, :sent_right_idx, :]

            _, sent_right_idx = del_zeros_right(cand_ds.sum(-1))
            cand_ds = cand_ds[:, :, :sent_right_idx, :]

        # embedding layer
        tar_doc_emb = self.embedding_layer(tar_d)
        tar_doc_mask = compute_mask(tar_d)

        cand_docs_emb = self.embedding_layer(cand_ds)
        cand_docs_mask = compute_mask(cand_ds)

        # target document encoder layer
        tar_doc_rep, _ = self.tar_doc_encoder(tar_doc_emb, tar_doc_mask)

        # candidate documents encoder layer
        batch, cand_doc_num = cand_docs_emb.size(0), cand_docs_emb.size(1)
        new_size = [batch * cand_doc_num] + list(cand_docs_emb.shape[2:])
        cand_docs_emb_flip = cand_docs_emb.view(*new_size)

        new_size = [batch * cand_doc_num] + list(cand_ds.shape[2:])
        cand_docs_mask_flip = cand_docs_mask.view(*new_size)

        cand_docs_rep_flip, _ = self.cand_doc_encoder(cand_docs_emb_flip,
                                                      cand_docs_mask_flip)
        cand_docs_rep = cand_docs_rep_flip.contiguous().view(
            batch, cand_doc_num, -1)

        # output layer
        cand_scores = torch.bmm(tar_doc_rep.unsqueeze(1),
                                cand_docs_rep.transpose(1, 2)).squeeze(
                                    1)  # (batch, cand_doc_num)
        cand_logits = torch.log_softmax(cand_scores, dim=-1)

        return cand_logits
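
The output layer above scores each candidate document by a dot product between the target representation and each candidate representation. A quick shape check with made-up sizes:

import torch

batch, cand_doc_num, hidden = 2, 3, 8
tar_doc_rep = torch.randn(batch, hidden)
cand_docs_rep = torch.randn(batch, cand_doc_num, hidden)

scores = torch.bmm(tar_doc_rep.unsqueeze(1),        # (batch, 1, hidden)
                   cand_docs_rep.transpose(1, 2))   # (batch, hidden, cand_doc_num)
scores = scores.squeeze(1)                          # (batch, cand_doc_num)
cand_logits = torch.log_softmax(scores, dim=-1)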
Example #6
    def collect_fun(self, batch):
        tar_d = []
        cand_ds = []
        pt_label = []
        gt_idx = []

        for ele in batch:
            tar_d.append(ele[0])
            cand_ds.append(ele[1])
            pt_label.append(ele[2])
            gt_idx.append(ele[3])

        tar_d = torch.stack(tar_d, dim=0)
        cand_ds = torch.stack(cand_ds, dim=0)
        pt_label = torch.stack(pt_label, dim=0)
        gt_idx = torch.tensor(gt_idx, dtype=torch.long)

        # compress on word level
        tar_d, _ = del_zeros_right(tar_d)
        tar_mask = compute_mask(tar_d, padding_idx=Vocabulary.PAD_IDX)

        cand_ds, _ = del_zeros_right(cand_ds)
        cand_mask = compute_mask(cand_ds, padding_idx=Vocabulary.PAD_IDX)

        # compress on sentence level
        if self.hierarchical:
            _, sent_right_idx = del_zeros_right(tar_mask.sum(-1))
            tar_d = tar_d[:, :sent_right_idx, :]
            tar_mask = tar_mask[:, :sent_right_idx, :]

            _, sent_right_idx = del_zeros_right(cand_mask.sum(-1))
            cand_ds = cand_ds[:, :, :sent_right_idx, :]
            cand_mask = cand_mask[:, :, :sent_right_idx, :]

        # logger.info('tar_d: {}, {}'.format(tar_d.dtype, tar_d.shape))
        # logger.info('cand_ds: {}, {}'.format(cand_ds.dtype, cand_ds.shape))
        # logger.info('label: {}, {}'.format(pt_label.dtype, pt_label.shape))

        return tar_d, tar_mask, cand_ds, cand_mask, pt_label, gt_idx
Example #7
    def forward(self, docs_rep, turn_nl):
        turn_nl, _ = del_zeros_right(turn_nl)

        # 1. user response encoding
        turn_emb = self.embedding_layer(turn_nl)
        turn_mask = compute_mask(turn_nl)

        # (batch, hidden_size * 2)
        turn_word_rep, turn_rep = self.turn_rnn(turn_emb, turn_mask)

        # 1.1 user response classification
        # (batch, num_slots)
        turn_slot_cls = self.turn_cls_linear(turn_rep)
        turn_slot_cls = torch.softmax(turn_slot_cls, dim=-1)

        # (batch, 1)
        turn_inform_sig = self.turn_inform_linear(turn_rep)
        turn_inform_sig = torch.sigmoid(turn_inform_sig)

        # (batch, num_slots + 1)
        turn_slot_inform_cls = torch.cat([(1 - turn_inform_sig) * turn_slot_cls,
                                          turn_inform_sig], dim=-1)

        # 2. compute similarity with the bilinear layer
        batch, num_docs, num_slots, _ = docs_rep.size()

        # (batch, num_docs, num_slots, hidden_size * 2)
        turn_rep_expand = turn_rep.unsqueeze(1).unsqueeze(1).expand(-1, num_docs, num_slots, -1).contiguous()

        # (batch, num_docs, num_slots)
        turn_docs_slots_similar = self.similar_bilinear(docs_rep, turn_rep_expand).squeeze(-1)
        extra_similar = turn_docs_slots_similar.new_ones(batch, num_docs, 1)

        # 2.1 similarity weighted by the slot distribution
        # (batch, num_docs, num_slots + 1)
        turn_docs_slots_similar_extra = torch.cat([turn_docs_slots_similar, extra_similar], dim=-1)
        # (batch * num_docs, num_slots + 1)
        turn_docs_slots_similar_extra = turn_docs_slots_similar_extra.reshape(batch * num_docs, -1)

        # (batch, num_docs, num_slots + 1)
        turn_slot_inform_cls_expand = turn_slot_inform_cls.unsqueeze(1).expand(-1, num_docs, -1)
        # (batch * num_docs, num_slots + 1)
        turn_slot_inform_cls_expand = turn_slot_inform_cls_expand.reshape(batch * num_docs, -1)

        # (batch * num_docs)
        turn_docs_similar = torch.bmm(turn_docs_slots_similar_extra.unsqueeze(1),
                                      turn_slot_inform_cls_expand.unsqueeze(2)).squeeze(1).squeeze(1)
        # (batch, num_docs)
        turn_level_docs_dist = torch.softmax(turn_docs_similar.view(batch, num_docs), dim=-1)

        return turn_level_docs_dist, turn_slot_cls, turn_inform_sig
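
A small sanity check on the slot/inform mixture above: because turn_slot_cls is a softmax output and turn_inform_sig a sigmoid, the concatenation forms a proper distribution over num_slots + 1 entries. The sizes below are illustrative only:

import torch

batch, num_slots = 2, 4
turn_slot_cls = torch.softmax(torch.randn(batch, num_slots), dim=-1)
turn_inform_sig = torch.sigmoid(torch.randn(batch, 1))

turn_slot_inform_cls = torch.cat([(1 - turn_inform_sig) * turn_slot_cls,
                                  turn_inform_sig], dim=-1)
assert torch.allclose(turn_slot_inform_cls.sum(dim=-1), torch.ones(batch))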
Example #8
    def forward(self, doc):
        doc, _ = del_zeros_right(doc)
        _, sent_right_idx = del_zeros_right(doc.sum(-1))
        doc = doc[:, :sent_right_idx, :]

        # embedding layer
        doc_emb = self.embedding_layer(doc)
        doc_mask = compute_mask(doc)

        # doc encoder layer
        tar_doc_rep, _ = self.tar_doc_encoder(doc_emb, doc_mask)
        cand_doc_rep, _ = self.cand_doc_encoder(doc_emb, doc_mask)

        # doc representation
        doc_rep = torch.cat([tar_doc_rep, cand_doc_rep], dim=-1)

        return doc_rep
Example #9
    def forward(self, contents, question_ans, logics, questions, answers):
        # contents size(16,10,200)
        # question_ans size(16,100)
        batch_size = question_ans.size()[0]
        max_content_len = contents.size()[2]
        max_question_len = question_ans.size()[1]
        contents_num = contents.size()[1]

        # word-level embedding: (seq_len, batch, embedding_size)
        question_vec, question_mask = self.fixed_embedding.forward(
            question_ans)  # size=(batch, 100, 200)
        delta_emb = self.delta_embedding.forward(
            question_ans)  # size=(batch, 100, 200)
        question_vec += delta_emb  # size=(batch, 100, 200)
        question_encode = self.context_layer.forward(
            question_vec, question_mask)  # size=(batch, 100, 256)

        contents = contents[:, 0:self.use_content_nums, :]  # use top-n contents, size=(batch, n, 200)
        viewed_contents = contents.contiguous().view(
            batch_size * self.use_content_nums,
            max_content_len)  # size=(batch*n, 200)
        viewed_content_vec, viewed_content_mask = self.fixed_embedding.forward(
            viewed_contents)  # size=(batch*n, 200,256)
        viewed_delta_emb = self.delta_embedding.forward(
            viewed_contents)  # size=(batch*n, 200,256)
        viewed_content_vec += viewed_delta_emb
        viewed_content_encode = self.context_layer.forward(
            viewed_content_vec,
            viewed_content_mask)  # size=(batch*n, 200, 256)
        content_encode = viewed_content_encode.view(
            batch_size, self.use_content_nums, max_content_len,
            self.hidden_size * 2).transpose(0, 1)  # size=(n,batch, 200, 256)

        viewed_reasoning_content_gating_val = self.reasoning_gating_layer(
            viewed_content_encode)  # size=(16*n,200,1)
        reasoning_question_gating_val = self.reasoning_gating_layer(
            question_encode)  # size=(16,100,1)

        # matching features
        matching_feature_row = []  # list[tensor(16,200,2)]
        matching_feature_col = []  # list[tensor(16,100,2)]
        # compute RnQ & RnD
        RnQ = []  # list[tensor[16,100,256]]
        RnD = []  # list[tensor[16,200,256]]
        D_RnD = []  # concatenation of D and RnD
        for i in range(self.use_content_nums):
            cur_content_encode = content_encode[i]  # size=(16,200,256), the i-th content
            cur_Matching_matrix = self.compute_matching_matrix(
                question_encode, cur_content_encode
            )  # (batch, question_len , content_len) eg(16,100,200)

            cur_an_matrix = torch.nn.functional.softmax(
                cur_Matching_matrix, dim=2
            )  # column-wise softmax: each row of the matching matrix sums to 1, size=(batch, question_len, content_len)
            cur_bn_matrix = torch.nn.functional.softmax(
                cur_Matching_matrix, dim=1
            )  # row-wise softmax: each column of the matching matrix sums to 1, size=(batch, question_len, content_len)

            cur_RnQ = self.compute_RnQ(
                cur_an_matrix, cur_content_encode
            )  # size=(batch, 100 , 256)     eg[16,100,256]
            cur_RnD = self.compute_RnD(cur_bn_matrix,
                                       question_encode)  # size=[16,200,256]
            cur_D_RnD = torch.cat([cur_content_encode, cur_RnD],
                                  dim=2)  # size=(16,200,512)
            D_RnD.append(cur_D_RnD)
            RnQ.append(cur_RnQ)
            RnD.append(cur_RnD)

            # compute the matching features
            cur_max_pooling_feature_row, _ = torch.max(cur_Matching_matrix,
                                                       dim=1)  # size=(16,200)
            cur_mean_pooling_feature_row = torch.mean(cur_Matching_matrix,
                                                      dim=1)  # size=(16,200)
            cur_matching_feature_row = torch.stack(
                [cur_max_pooling_feature_row, cur_mean_pooling_feature_row],
                dim=-1)  # size=(16,200,2)
            matching_feature_row.append(cur_matching_feature_row)

            cur_max_pooling_feature_col, _ = torch.max(cur_Matching_matrix,
                                                       dim=2)  # size=(16,100)
            cur_mean_pooling_feature_col = torch.mean(cur_Matching_matrix,
                                                      dim=2)  # size=(16,100)
            cur_matching_feature_col = torch.stack(
                [cur_max_pooling_feature_col, cur_mean_pooling_feature_col],
                dim=-1)  # size=(16,100,2)
            matching_feature_col.append(cur_matching_feature_col)

        RmD = []  # list[tensor(16,200,512)]
        for i in range(self.use_content_nums):
            D_RnD_m = D_RnD[i]  # size=(16,200,512)
            Mmn_i = []
            RmD_i = []
            for j in range(self.use_content_nums):
                D_RnD_n = D_RnD[j]  # size=(16,200,512)
                # attention between every pair of documents, Mmn_i_j size=(16,200,200)
                Mmn_i_j = self.compute_cross_document_attention(
                    D_RnD_m, D_RnD_n)
                Mmn_i.append(Mmn_i_j)

            Mmn_i = torch.stack(Mmn_i, dim=-1)  # size=(16,200,200,10)
            softmax_Mmn_i = self.reduce_softmax(Mmn_i)  # size=(16,200,200,10)

            for j in range(self.use_content_nums):
                D_RnD_n = D_RnD[j]  # size=(16,200,512)
                beta_mn_i_j = softmax_Mmn_i[:, :, :, j]  # size=(16,200,200)
                cur_RmD = torch.bmm(beta_mn_i_j, D_RnD_n)  # size=(16,200,512)
                RmD_i.append(cur_RmD)

            RmD_i = torch.stack(RmD_i, dim=1)  # size=(16,10,200,512)
            RmD_i = torch.sum(RmD_i, dim=1)  # size=(16,200,512)
            RmD.append(RmD_i)

        RnQ = torch.stack(RnQ, dim=1)  # size 16,n,100,256
        viewed_RnQ = RnQ.view(batch_size * self.use_content_nums,
                              max_question_len,
                              self.hidden_size * 2)  # size 16*n,100,256

        RmD = torch.stack(RmD, dim=1)  # size 16,n,200,256
        viewed_RmD = RmD.view(batch_size * self.use_content_nums,
                              max_content_len,
                              self.hidden_size * 4)  # size 16*n,200,512

        viewed_matching_feature_col = torch.stack(
            matching_feature_col,
            dim=1).view(batch_size * self.use_content_nums, max_question_len,
                        2)  # size 16*n,100,2
        viewed_matching_feature_row = torch.stack(
            matching_feature_row,
            dim=1).view(batch_size * self.use_content_nums, max_content_len,
                        2)  # size 16*n,200,2

        viewed_RnQ = torch.cat([viewed_RnQ, viewed_matching_feature_col],
                               dim=-1)  # size 16*n,100,258
        viewed_RmD = torch.cat([viewed_RmD, viewed_matching_feature_row],
                               dim=-1)  # size 16*n,200,514

        viewed_RnQ_mask = compute_mask(
            viewed_RnQ.mean(dim=2),
            PreprocessData.padding_idx)  # size 16*n,100,258
        viewed_RmD_mask = compute_mask(
            viewed_RmD.mean(dim=2),
            PreprocessData.padding_idx)  # size 16*n,200,514
        viewed_reasoning_question_gating_val = reasoning_question_gating_val.unsqueeze(1) \
            .repeat(1, self.use_content_nums, 1, 1) \
            .view(batch_size * self.use_content_nums, max_question_len, 1)  # size 16*n,100,1

        gated_cur_RnQ = self.compute_gated_value(
            viewed_RnQ, viewed_reasoning_question_gating_val)
        gated_cur_RmD = self.compute_gated_value(
            viewed_RmD, viewed_reasoning_content_gating_val)

        # pass through the reasoning layer
        cur_RnQ_reasoning_out = self.question_reasoning_layer.forward(
            gated_cur_RnQ, viewed_RnQ_mask)  # size=(16*n,100,256)
        cur_RmD_reasoning_out = self.content_reasoning_layer.forward(
            gated_cur_RmD, viewed_RmD_mask)  # size=(16*n,200,256)

        # pooling after the reasoning layer to collapse the sequence dimension
        RnQ_reasoning_out_max_pooling, _ = torch.max(cur_RnQ_reasoning_out,
                                                     dim=1)  # size=(16*n,256)
        RmD_reasoning_out_max_pooling, _ = torch.max(cur_RmD_reasoning_out,
                                                     dim=1)  # size(16*n,256)
        viewed_reasoning_feature = torch.cat(
            [RnQ_reasoning_out_max_pooling, RmD_reasoning_out_max_pooling],
            dim=1)  # size(16*n,512)
        reasoning_feature = viewed_reasoning_feature.view(
            batch_size, self.use_content_nums,
            self.hidden_size * 4)  # size=(16,10,512)
        noise_gate_val = self.decision_gating_layer(
            reasoning_feature)  # size=(16,10,1)
        gated_reasoning_feature = self.compute_gated_value(
            reasoning_feature, noise_gate_val)  # size=(16,10,512)

        reasoning_out_max_feature, _ = torch.max(gated_reasoning_feature,
                                                 dim=1)  # size=(16,512)
        reasoning_out_mean_feature = torch.mean(gated_reasoning_feature,
                                                dim=1)  # size=(16,512)
        decision_input = torch.cat(
            [reasoning_out_max_feature, reasoning_out_mean_feature],
            dim=1)  # size(16,1024)

        # self attention feature
        qa_attention = self.self_attention_layer.forward(
            question_encode)  # size=(16,100,50)
        qa_attention = torch.nn.functional.softmax(qa_attention,
                                                   dim=1)  # size=(16,100,50)
        qa_match = torch.bmm(question_encode.transpose(1, 2),
                             qa_attention)  #size=(16,256,50)
        qa_matching_feature = torch.sum(qa_match, dim=1)  #size=(16,50)

        decision_input = torch.cat([decision_input, qa_matching_feature],
                                   dim=1)  # size(16,1024+500)

        decision_output = self.decision_layer.forward(
            decision_input)  # size=(16,1)
        logics = logics.view(logics.size(0), 1)
        output = decision_output * logics

        output = output.view(output.size(0) // 5, 5)
        return output  # logics is -1 for a reversed statement and 1 for a forward one, so it flips the sign
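
The matching helpers used in the forward passes above and below (compute_matching_matrix, compute_RnQ, compute_RnD) are not included in this listing. A minimal sketch, assuming plain batched dot-product attention consistent with the commented shapes (the real implementations may use a bilinear form or scaling):

import torch

def compute_matching_matrix(question_encode, content_encode):
    # (batch, question_len, hidden) x (batch, hidden, content_len)
    #   -> (batch, question_len, content_len)
    return torch.bmm(question_encode, content_encode.transpose(1, 2))

def compute_RnQ(an_matrix, content_encode):
    # an_matrix is softmaxed over the content dimension (dim=2), so each
    # question position attends over content positions.
    return torch.bmm(an_matrix, content_encode)  # (batch, question_len, hidden)

def compute_RnD(bn_matrix, question_encode):
    # bn_matrix is softmaxed over the question dimension (dim=1), so each
    # content position attends over question positions.
    return torch.bmm(bn_matrix.transpose(1, 2), question_encode)  # (batch, content_len, hidden)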
Example #10
    def forward(self, contents, question_ans, logics, contents_char=None, question_ans_char=None):

        # assert contents_char is not None and question_ans_char is not None
        batch_size = question_ans.size()[0]
        max_content_len = contents.size()[2]
        max_question_len = question_ans.size()[1]
        contents_num = contents.size()[1]
        # word-level embedding: (seq_len, batch, embedding_size)
        content_vec = []
        content_mask = []
        question_vec, question_mask = self.embedding.forward(question_ans)
        for i in range(contents_num):
            cur_content = contents[:, i, :]
            cur_content_vec, cur_content_mask = self.embedding.forward(cur_content)
            content_vec.append(cur_content_vec)
            content_mask.append(cur_content_mask)

        # char-level embedding: (seq_len, batch, char_embedding_size)
        # context_emb_char, context_char_mask = self.char_embedding.forward(context_char)
        # question_emb_char, question_char_mask = self.char_embedding.forward(question_char)
        question_encode, _ = self.context_layer.forward(question_vec,question_mask)  # size=(cur_batch_max_questionans_len, batch, 256)
        content_encode = []  # word-level encode: (seq_len, batch, hidden_size)
        for i in range(contents_num):
            cur_content_vec = content_vec[i]
            cur_content_mask = content_mask[i]
            cur_content_encode, _ = self.context_layer.forward(cur_content_vec,cur_content_mask)  # size=(cur_batch_max_content_len, batch, 256)
            content_encode.append(cur_content_encode)

        # pad all encoded contents to the same length (200) and all encoded questions to the same length (100)
        same_sized_content_encode = []
        for i in range(contents_num):
            cur_content_encode = content_encode[i]
            cur_content_encode = self.full_matrix_to_specify_size(cur_content_encode, [max_content_len, batch_size,cur_content_encode.size()[2]])  # size=(200,16,256)
            same_sized_content_encode.append(cur_content_encode)
        same_sized_question_encode = self.full_matrix_to_specify_size(question_encode, [max_question_len, batch_size,question_encode.size()[2]])  # size=(100,16,256)


        # compute the gating-layer values
        reasoning_content_gating_val = []
        reasoning_question_gating_val = None
        decision_content_gating_val = []
        decision_question_gating_val = None
        for i in range(contents_num):
            cur_content_encode = same_sized_content_encode[i]  # size=(200,16,256)
            cur_gating_input = cur_content_encode.permute(1, 2, 0)  # size=(16,256,200)
            cur_reasoning_content_gating_val = self.reasoning_gating_layer(cur_gating_input)  # size=(16,1,200)
            cur_reasoning_content_gating_val = cur_reasoning_content_gating_val + 0.00001  # avoid zero gates, which break the later pad_sequence step

            cur_decision_content_gating_val = self.decision_gating_layer(cur_gating_input)  # size=(16,1,200)
            cur_decision_content_gating_val = cur_decision_content_gating_val + 0.00001  # avoid zero gates, which break the later pad_sequence step
            reasoning_content_gating_val.append(cur_reasoning_content_gating_val)
            decision_content_gating_val.append(cur_decision_content_gating_val)

        question_gating_input = same_sized_question_encode.permute(1, 2, 0)  # size=(16,256,100)
        reasoning_question_gating_val = self.reasoning_gating_layer(question_gating_input)  # size=(16,1,100)
        reasoning_question_gating_val = reasoning_question_gating_val + 0.00001  # avoid zero gates, which break the later pad_sequence step
        decision_question_gating_val = self.decision_gating_layer(question_gating_input)  # size=(16,1,100)
        decision_question_gating_val = decision_question_gating_val + 0.00001  # avoid zero gates, which break the later pad_sequence step


        # compute gate loss; TODO: returning multiple values apparently is not supported here, unused for now
        # question_gate_val = torch.cat([reasoning_question_gating_val.view(-1), decision_question_gating_val.view(-1)])
        # reasoning_gate_val = torch.cat([ele.view(-1) for ele in reasoning_content_gating_val])
        # decision_gate_val = torch.cat([ele.view(-1) for ele in decision_content_gating_val])
        # all_gate_val = torch.cat([question_gate_val, reasoning_gate_val, decision_gate_val])
        # mean_gate_val = torch.mean(all_gate_val)


        # matching matrix computation: the question is matched against every content
        Matching_matrix = []
        for i in range(contents_num):
            cur_content_encode = same_sized_content_encode[i]
            cur_Matching_matrix = self.compute_matching_matrix(same_sized_question_encode,
                                                               cur_content_encode)  # (batch, question_len , content_len) eg(16,100,200)
            Matching_matrix.append(cur_Matching_matrix)

        # compute an & bn
        an_matrix = []
        bn_matrix = []
        for i in range(contents_num):
            cur_Matching_matrix = Matching_matrix[i]
            cur_an_matrix = torch.nn.functional.softmax(cur_Matching_matrix, dim=2)  # column-wise softmax: each row of the matching matrix sums to 1, size=(batch, question_len, content_len)
            cur_bn_matrix = torch.nn.functional.softmax(cur_Matching_matrix, dim=1)  # row-wise softmax: each column of the matching matrix sums to 1, size=(batch, question_len, content_len)
            an_matrix.append(cur_an_matrix)
            bn_matrix.append(cur_bn_matrix)


        # compute RnQ & RnD
        RnQ = []  # list[tensor[16,100,256]]
        RnD = []
        for i in range(contents_num):
            cur_an_matrix = an_matrix[i]
            cur_content_encode = same_sized_content_encode[i]
            cur_bn_matrix = bn_matrix[i]
            cur_RnQ = self.compute_RnQ(cur_an_matrix, cur_content_encode)  # size=(batch, curbatch_max_question_len , 256)     eg[16,100,256]
            cur_RnD = self.compute_RnD(cur_bn_matrix,same_sized_question_encode)  # size=(batch, curbatch_max_content_len , 256)    eg[16,200,256]
            RnQ.append(cur_RnQ)
            RnD.append(cur_RnD)


        ########### compute Mmn' ##############
        D_RnD = []  # first build the concatenation of D and RnD
        for i in range(contents_num):
            cur_content_encode = same_sized_content_encode[i].transpose(0, 1)  # size=(16,200,256)
            cur_RnD = RnD[i]  # size=(16,200,256)
            # embed()
            cur_D_RnD = torch.cat([cur_content_encode, cur_RnD], dim=2)  # size=(16,200,512)
            D_RnD.append(cur_D_RnD)

        RmD = []  # list[tensor(16,200,512)]
        for i in range(contents_num):
            D_RnD_m = D_RnD[i]  # size=(16,200,512)
            Mmn_i = []
            RmD_i = []
            for j in range(contents_num):
                D_RnD_n = D_RnD[j]  # size=(16,200,512)
                Mmn_i_j = self.compute_cross_document_attention(D_RnD_m, D_RnD_n)  # attention between every pair of documents, Mmn_i_j size=(16,200,200)
                Mmn_i.append(Mmn_i_j)

            Mmn_i = torch.stack(Mmn_i).permute(1, 2, 3, 0)  # size=(16,200,200,10)
            softmax_Mmn_i = self.reduce_softmax(Mmn_i)  # size=(16,200,200,10)

            for j in range(contents_num):
                D_RnD_n = D_RnD[j]  # size=(16,200,512)
                beta_mn_i_j = softmax_Mmn_i[:,:,:,j]
                cur_RmD = torch.bmm(beta_mn_i_j, D_RnD_n)  # size=(16,200,512)
                RmD_i.append(cur_RmD)


            RmD_i = torch.stack(RmD_i)  # size=(10,16,200,512)
            RmD_i = RmD_i.transpose(0, 1)  # size=(16,10,200,512)
            RmD_i = torch.sum(RmD_i, dim=1)  # size=(16,200,512)
            RmD.append(RmD_i)

        # RmD=torch.stack(RmD).transpose(0,1) #size=(16,10,200,512)


        matching_feature_row = []  # list[tensor(16,200,2)]
        matching_feature_col = []  # list[tensor(16,100,2)]
        for i in range(contents_num):
            cur_Matching_matrix = Matching_matrix[i]  # size=(16,100,200)
            cur_max_pooling_feature_row, _ = torch.max(cur_Matching_matrix, dim=1)  # size=(16,200)
            cur_mean_pooling_feature_row = torch.mean(cur_Matching_matrix, dim=1)  # size=(16,200)
            cur_matching_feature_row = torch.stack([cur_max_pooling_feature_row, cur_mean_pooling_feature_row]).permute(1,2,0)  # size=(16,200,2)
            matching_feature_row.append(cur_matching_feature_row)

            cur_max_pooling_feature_col, _ = torch.max(cur_Matching_matrix, dim=2)  # size=(16,100)
            cur_mean_pooling_feature_col = torch.mean(cur_Matching_matrix, dim=2)  # size=(16,100)
            cur_matching_feature_col = torch.stack([cur_max_pooling_feature_col, cur_mean_pooling_feature_col]).permute(1,2,0)  # size=(16,100,2)
            matching_feature_col.append(cur_matching_feature_col)
        # print(253)
        # embed()
        reasoning_feature = []
        RnQ_reasoning_out = []
        RmD_reasoning_out = []
        for i in range(contents_num):
            cur_RnQ = RnQ[i]  # size=(16,100,256)
            cur_RmD = RmD[i]  # size=(16,200,512)
            cur_matching_feature_col = matching_feature_col[i]  # size=(16,100,2)
            cur_matching_feature_row = matching_feature_row[i]  # size=(16,200,2)

            cur_RnQ = torch.cat([cur_RnQ, cur_matching_feature_col], dim=2)  # size=(16,100,258)
            cur_RmD = torch.cat([cur_RmD, cur_matching_feature_row], dim=2)  # size=(16,200,514)

            cur_RnQ_mask = compute_mask(cur_RnQ.mean(dim=2), PreprocessData.padding_idx)
            cur_RmD_mask = compute_mask(cur_RmD.mean(dim=2), PreprocessData.padding_idx)

            gated_cur_RnQ = self.compute_gated_value(cur_RnQ, reasoning_question_gating_val)  # size=(16,100,258)
            gated_cur_RmD = self.compute_gated_value(cur_RmD, reasoning_content_gating_val[i])  # size=(16,200,514)

            # pass through the reasoning layer
            cur_RnQ_reasoning_out, _ = self.question_reasoning_layer.forward(gated_cur_RnQ.transpose(0, 1), cur_RnQ_mask)  # size=(max_sequence_len,16,256)
            cur_RmD_reasoning_out, _ = self.content_reasoning_layer.forward(gated_cur_RmD.transpose(0, 1), cur_RmD_mask)  # size=(max_sequence_len,16,256)

            # pad all matrices to the same size
            cur_RnQ_reasoning_out = self.full_matrix_to_specify_size(cur_RnQ_reasoning_out,
                                                                     [max_question_len, batch_size,
                                                                      cur_RnQ_reasoning_out.size()[2]])  # size=(100,16,256)
            cur_RmD_reasoning_out = self.full_matrix_to_specify_size(cur_RmD_reasoning_out,
                                                                     [max_content_len, batch_size,
                                                                      cur_RmD_reasoning_out.size()[2]])  # size=(200,16,256)

            # back to batch-first before the decision-layer gate
            cur_RnQ_reasoning_out = cur_RnQ_reasoning_out.transpose(0, 1)  # size=(16,100,256)
            cur_RmD_reasoning_out = cur_RmD_reasoning_out.transpose(0, 1)  # size=(16,200,256)

            RnQ_reasoning_out.append(cur_RnQ_reasoning_out)
            RmD_reasoning_out.append(cur_RmD_reasoning_out)

            # gated_RnQ_out=self.compute_gated_value(cur_RnQ_reasoning_out,decision_question_gating_val)#size(16,100,256)
            # gated_RmD_out=self.compute_gated_value(cur_RmD_reasoning_out,decision_content_gating_val[i])#size(16,200,256)
            #
            # # concatenate the two features into a 300x256 representation
            # cur_reasoning_feature = torch.cat([gated_RnQ_out, gated_RmD_out], dim=1)  # size(16,300,256) ||  when content=100 size(16,200,256)
            # reasoning_feature.append(cur_reasoning_feature)

        # stack the 10 documents together
        RnQ_reasoning_out = torch.stack(RnQ_reasoning_out).transpose(0, 1)  # size=(16,10,100,256)
        RmD_reasoning_out = torch.stack(RmD_reasoning_out).transpose(0, 1)  # size=(16,10,200,256)

        RnQ_reasoning_out_maxpool, _ = torch.max(RnQ_reasoning_out, dim=1)  # size=(16,100,256)
        RmD_reasoning_out_maxpool, _ = torch.max(RmD_reasoning_out, dim=1)  # size=(16,200,256)

        # gated_RnQ_reasoning_out_maxpool=self.compute_gated_value(RnQ_reasoning_out_maxpool,decision_question_gating_val) #size(16,100,256)
        # gated_RmD_reasoning_out_maxpool=self.compute_gated_value(RmD_reasoning_out_maxpool,decision_content_gating_val)

        fc_input_RnQ_maxpool, _ = torch.max(RnQ_reasoning_out_maxpool, dim=1)  # size=(16,256)
        fc_input_RnQ_meanpool = torch.mean(RnQ_reasoning_out_maxpool, dim=1)  # size=(16,256)

        fc_input_RmD_maxpool, _ = torch.max(RmD_reasoning_out_maxpool, dim=1)  # size=(16,256)
        fc_input_RmD_meanpool = torch.mean(RmD_reasoning_out_maxpool, dim=1)  # size=(16,256)

        fc_input = torch.cat([fc_input_RnQ_maxpool, fc_input_RnQ_meanpool, fc_input_RmD_maxpool, fc_input_RmD_meanpool], dim=1)  # size=(16,1024)

        # reasoning_feature = torch.cat(reasoning_feature, dim=1)  # size=(16,3000,256)    |  when content=100 size(16,2000,256)
        # # print(299)
        # # embed()
        # maxpooling_reasoning_feature_column, _ = torch.max(reasoning_feature, dim=1)  # size(16,256)
        # meanpooling_reasoning_feature_column = torch.mean(reasoning_feature, dim=1)  # size(16,256)
        #
        # maxpooling_reasoning_feature_row, _ = torch.max(reasoning_feature, dim=2)  # size=(16,3000)    |  when content=100 size(16,2000)
        # meanpooling_reasoning_feature_row = torch.mean(reasoning_feature, dim=2)  # size=(16,3000)      |  when content=100 size(16,2000)
        # print(228, "============================")

        # pooling_reasoning_feature = torch.cat([maxpooling_reasoning_feature_row, meanpooling_reasoning_feature_row, maxpooling_reasoning_feature_column,meanpooling_reasoning_feature_column], dim=1)
        decision_input = fc_input.view(batch_size // 5, 5120)  # size=(16,1024*5), five-way classification
        #
        # print(312)
        # embed()
        output = self.decision_layer.forward(decision_input)  # size=(batchsize/5,5)



        # temp_gate_val=torch.stack([mean_gate_val,torch.tensor(0.0).to(self.device)]).resize_(1,2)
        # output_with_gate_val=torch.cat([output,temp_gate_val],dim=0)
        # logics=logics.resize_(logics.size()[0],1)
        return output  # logics is -1 when the statement polarity is reversed and 1 when it is forward
Example #11
    def forward(self, contents, question_ans, logics, contents_char=None, question_ans_char=None):
        # assert contents_char is not None and question_ans_char is not None
        batch_size = question_ans.size()[0]
        max_content_len = contents.size()[2]
        max_question_len = question_ans.size()[1]
        contents_num = contents.size()[1]
        # word-level embedding: (seq_len, batch, embedding_size)
        content_vec = []
        content_mask = []
        question_vec, question_mask = self.fixed_embedding.forward(question_ans)
        delta_emb = self.delta_embedding.forward(question_ans)
        question_vec += delta_emb
        question_encode = self.context_layer.forward(question_vec, question_mask)  # size=(batch, 100, 256)
        content_encode = []  # word-level encode: (seq_len, batch, hidden_size)
        for i in range(contents_num):
            cur_content = contents[:, i, :]
            cur_content_vec, cur_content_mask = self.fixed_embedding.forward(cur_content)
            delta_emb = self.delta_embedding.forward(cur_content)
            cur_content_vec+=delta_emb
            cur_content_encode = self.context_layer.forward(cur_content_vec, cur_content_mask)  # size=(batch, 200, 256)
            content_encode.append(cur_content_encode)

        # compute the gating-layer values
        reasoning_content_gating_val = []
        reasoning_question_gating_val = None
        for i in range(contents_num):
            cur_content_encode = content_encode[i]  # size=(16,200,256)
            cur_reasoning_content_gating_val = self.reasoning_gating_layer(cur_content_encode)  # size=(16,200,1)
            cur_reasoning_content_gating_val = cur_reasoning_content_gating_val + 1e-8  # avoid zero gates, which break the later pad_sequence step
            reasoning_content_gating_val.append(cur_reasoning_content_gating_val)
        reasoning_question_gating_val = self.reasoning_gating_layer(question_encode)  # size=(16,100,1)
        reasoning_question_gating_val = reasoning_question_gating_val + 1e-8  # avoid zero gates, which break the later pad_sequence step

        # compute gate loss; TODO: returning multiple values apparently is not supported here, unused for now
        # question_gate_val = torch.cat([reasoning_question_gating_val.view(-1), decision_question_gating_val.view(-1)])
        # reasoning_gate_val = torch.cat([ele.view(-1) for ele in reasoning_content_gating_val])
        # decision_gate_val = torch.cat([ele.view(-1) for ele in decision_content_gating_val])
        # all_gate_val = torch.cat([question_gate_val, reasoning_gate_val, decision_gate_val])
        # mean_gate_val = torch.mean(all_gate_val)


        # matching matrix computation: the question is matched against every content

        # matching features
        matching_feature_row = []  # list[tensor(16,200,2)]
        matching_feature_col = []  # list[tensor(16,100,2)]
        # compute RnQ & RnD
        RnQ = []  # list[tensor[16,100,256]]
        RnD = []  # list[tensor[16,200,256]]
        D_RnD = []  # concatenation of D and RnD
        for i in range(contents_num):
            cur_content_encode = content_encode[i]  # size=(16,200,256), the i-th content
            cur_Matching_matrix = self.compute_matching_matrix(question_encode, cur_content_encode)  # (batch, question_len, content_len), e.g. (16,100,200)

            cur_an_matrix = torch.nn.functional.softmax(cur_Matching_matrix, dim=2)  # column-wise softmax: each row of the matching matrix sums to 1, size=(batch, question_len, content_len)
            cur_bn_matrix = torch.nn.functional.softmax(cur_Matching_matrix, dim=1)  # row-wise softmax: each column of the matching matrix sums to 1, size=(batch, question_len, content_len)

            cur_RnQ = self.compute_RnQ(cur_an_matrix, cur_content_encode)  # size=(batch, 100 , 256)     eg[16,100,256]
            cur_RnD = self.compute_RnD(cur_bn_matrix, question_encode)  # size=[16,200,256]
            cur_D_RnD = torch.cat([cur_content_encode, cur_RnD], dim=2)  # size=(16,200,512)
            D_RnD.append(cur_D_RnD)
            RnQ.append(cur_RnQ)
            RnD.append(cur_RnD)

            # compute the matching features
            cur_max_pooling_feature_row, _ = torch.max(cur_Matching_matrix, dim=1)  # size=(16,200)
            cur_mean_pooling_feature_row = torch.mean(cur_Matching_matrix, dim=1)  # size=(16,200)
            cur_matching_feature_row = torch.stack([cur_max_pooling_feature_row, cur_mean_pooling_feature_row], dim=-1)  # size=(16,200,2)
            matching_feature_row.append(cur_matching_feature_row)

            cur_max_pooling_feature_col, _ = torch.max(cur_Matching_matrix, dim=2)  # size=(16,100)
            cur_mean_pooling_feature_col = torch.mean(cur_Matching_matrix, dim=2)  # size=(16,100)
            cur_matching_feature_col = torch.stack([cur_max_pooling_feature_col, cur_mean_pooling_feature_col], dim=-1)  # size=(16,100,2)
            matching_feature_col.append(cur_matching_feature_col)

        RmD = []  # list[tensor(16,200,512)]
        for i in range(contents_num):
            D_RnD_m = D_RnD[i]  # size=(16,200,512)
            Mmn_i = []
            RmD_i = []
            for j in range(contents_num):
                D_RnD_n = D_RnD[j]  # size=(16,200,512)
                # attention between every pair of documents, Mmn_i_j size=(16,200,200)
                Mmn_i_j = self.compute_cross_document_attention(D_RnD_m, D_RnD_n)
                Mmn_i.append(Mmn_i_j)

            Mmn_i = torch.stack(Mmn_i,dim=-1)  # size=(16,200,200,10)
            softmax_Mmn_i = self.reduce_softmax(Mmn_i)  # size=(16,200,200,10)

            for j in range(contents_num):
                D_RnD_n = D_RnD[j]  # size=(16,200,512)
                beta_mn_i_j = softmax_Mmn_i[:, :, :, j] # size=(16,200,200)
                cur_RmD = torch.bmm(beta_mn_i_j, D_RnD_n)  # size=(16,200,512)
                RmD_i.append(cur_RmD)

            RmD_i = torch.stack(RmD_i,dim=1)  # size=(16,10,200,512)
            RmD_i = torch.sum(RmD_i, dim=1)  # size=(16,200,512)
            RmD.append(RmD_i)

        reasoning_feature = []
        for i in range(contents_num):
            cur_RnQ = RnQ[i]  # size=(16,100,256)
            cur_RmD = RmD[i]  # size=(16,200,512)
            cur_matching_feature_col = matching_feature_col[i]  # size=(16,100,2)
            cur_matching_feature_row = matching_feature_row[i]  # size=(16,200,2)

            cur_RnQ = torch.cat([cur_RnQ, cur_matching_feature_col], dim=2)  # size=(16,100,258)
            cur_RmD = torch.cat([cur_RmD, cur_matching_feature_row], dim=2)  # size=(16,200,514)

            cur_RnQ_mask = compute_mask(cur_RnQ.mean(dim=2), PreprocessData.padding_idx)
            cur_RmD_mask = compute_mask(cur_RmD.mean(dim=2), PreprocessData.padding_idx)

            gated_cur_RnQ = self.compute_gated_value(cur_RnQ, reasoning_question_gating_val)  # size=(16,100,258)
            gated_cur_RmD = self.compute_gated_value(cur_RmD, reasoning_content_gating_val[i])  # size=(16,200,514)

            # pass through the reasoning layer
            cur_RnQ_reasoning_out = self.question_reasoning_layer.forward(gated_cur_RnQ, cur_RnQ_mask)  # size=(16,100,256)
            cur_RmD_reasoning_out = self.content_reasoning_layer.forward(gated_cur_RmD, cur_RmD_mask)  # size=(16,200,256)

            RnQ_reasoning_out_max_pooling, _ = torch.max(cur_RnQ_reasoning_out, dim=1)  # size=(16,256)
            RmD_reasoning_out_max_pooling, _ = torch.max(cur_RmD_reasoning_out, dim=1)  # size=(16,256)

            cur_reasoning_feature = torch.cat([RnQ_reasoning_out_max_pooling, RmD_reasoning_out_max_pooling], dim=1)  # size=(16,512)
            reasoning_feature.append(cur_reasoning_feature)

        reasoning_feature = torch.stack(reasoning_feature, dim=1)  # size=(16,10,512)
        noise_gate_val = self.decision_gating_layer(reasoning_feature)  # size=(16,10,1)
        gated_reasoning_feature = self.compute_gated_value(reasoning_feature, noise_gate_val)  # size=(16,10,512)

        reasoning_out_max_feature, _ = torch.max(gated_reasoning_feature, dim=1)  # size=(16,512)
        reasoning_out_mean_feature = torch.mean(gated_reasoning_feature, dim=1)  # size=(16,512)

        decision_input = torch.cat([reasoning_out_max_feature, reasoning_out_mean_feature], dim=1)  # size=(16,1024)

        decision_output = self.decision_layer.forward(decision_input)  # size=(16,1)
        logics = logics.view(logics.size(0), 1)
        output = decision_output * logics

        output = output.view(output.size(0) // 5, 5)
        return output  # logics is -1 for a reversed statement and 1 for a forward one, so it flips the sign