def forward(self, x):
    """Embed token ids and return seq-first embeddings plus a padding mask.

    Args:
        x: integer token-id tensor; presumably (batch, seq_len) — TODO confirm.

    Returns:
        Tuple of (embeddings with dims 0 and 1 swapped, mask from compute_mask).
    """
    padding_mask = compute_mask(x)
    embedded = self.embedding_layer.forward(x)
    # Swap the first two dimensions (e.g. batch-first -> seq-first).
    return embedded.transpose(0, 1), padding_mask
def collect_fun(self, batch):
    """Collate a list of (doc, label) pairs into batched, right-trimmed tensors.

    Args:
        batch: iterable of samples where element 0 is a document tensor and
            element 1 is a label tensor.

    Returns:
        Tuple of (docs, docs_mask, labels).
    """
    docs = torch.stack([sample[0] for sample in batch], dim=0)
    labels = torch.stack([sample[1] for sample in batch], dim=0)
    # Trim trailing all-zero columns on the word level, then build the mask.
    docs, _ = del_zeros_right(docs)
    docs_mask = compute_mask(docs, padding_idx=Vocabulary.PAD_IDX)
    if self.hierarchical:
        # Also trim trailing empty sentences (rows whose mask sums to zero).
        _, sent_right_idx = del_zeros_right(docs_mask.sum(-1))
        docs = docs[:, :sent_right_idx, :]
        docs_mask = docs_mask[:, :sent_right_idx, :]
    return docs, docs_mask, labels
def forward(self, x):
    """Embed token ids and return the embeddings plus a padding mask.

    The mask has the same shape as ``x``: 1 where a real token is present,
    0 at padding positions (``PreprocessData.padding_idx``).
    """
    padding_mask = compute_mask(x, PreprocessData.padding_idx)
    # NOTE(review): no transpose is applied here — the embeddings are
    # returned exactly as produced by the embedding layer.
    embedded = self.embedding_layer.forward(x)
    return embedded, padding_mask
def forward(self, x):
    """Character-level embedding: (batch, seq, word) char ids -> embeddings + mask."""
    batch_size, seq_len, word_len = x.shape
    # Flatten batch and sequence dims so the embedding layer sees 2-D input.
    flat_ids = x.view(-1, word_len)
    flat_mask = compute_mask(flat_ids, 0)  # char-level padding index is zero
    flat_emb = self.embedding_layer.forward(flat_ids)
    # Restore the (batch, seq, word, ...) layout.
    emb = flat_emb.view(batch_size, seq_len, word_len, -1)
    mask = flat_mask.view(batch_size, seq_len, word_len)
    return emb, mask
def forward(self, tar_d, cand_ds):
    """Score candidate documents against a target document.

    Args:
        tar_d: target-document token ids; assumed (batch, sent, word) when
            ``self.doc_hierarchical`` — TODO confirm.
        cand_ds: candidate-document token ids with an extra candidate dim.

    Returns:
        Log-softmax scores over candidates, shape (batch, cand_doc_num).
    """
    # Trim trailing zero padding on the word level.
    tar_d, _ = del_zeros_right(tar_d)
    cand_ds, _ = del_zeros_right(cand_ds)
    if self.doc_hierarchical:
        # Trim trailing empty sentences as well.
        _, sent_right_idx = del_zeros_right(tar_d.sum(-1))
        tar_d = tar_d[:, :sent_right_idx, :]
        _, sent_right_idx = del_zeros_right(cand_ds.sum(-1))
        cand_ds = cand_ds[:, :, :sent_right_idx, :]
    # embedding layer
    tar_doc_emb = self.embedding_layer(tar_d)
    tar_doc_mask = compute_mask(tar_d)
    cand_docs_emb = self.embedding_layer(cand_ds)
    cand_docs_mask = compute_mask(cand_ds)
    # target document encoder layer
    tar_doc_rep, _ = self.tar_doc_encoder(tar_doc_emb, tar_doc_mask)
    # candidate documents encoder layer: fold the candidate dim into the
    # batch dim so the encoder sees one document per row.
    batch, cand_doc_num = cand_docs_emb.size(0), cand_docs_emb.size(1)
    new_size = [batch * cand_doc_num] + list(cand_docs_emb.shape[2:])
    cand_docs_emb_flip = cand_docs_emb.view(*new_size)
    new_size = [batch * cand_doc_num] + list(cand_ds.shape[2:])
    cand_docs_mask_flip = cand_docs_mask.view(*new_size)
    cand_docs_rep_flip, _ = self.cand_doc_encoder(cand_docs_emb_flip,
                                                  cand_docs_mask_flip)
    # Unfold back to (batch, cand_doc_num, rep_dim).
    cand_docs_rep = cand_docs_rep_flip.contiguous().view(
        batch, cand_doc_num, -1)
    # output layer: dot product of target rep with each candidate rep.
    cand_scores = torch.bmm(tar_doc_rep.unsqueeze(1),
                            cand_docs_rep.transpose(1, 2)).squeeze(
        1)  # (batch, cand_doc_num)
    cand_logits = torch.log_softmax(cand_scores, dim=-1)
    return cand_logits
def collect_fun(self, batch):
    """Collate (target_doc, cand_docs, pt_label, gt_idx) samples into tensors.

    Returns:
        Tuple of (tar_d, tar_mask, cand_ds, cand_mask, pt_label, gt_idx).
    """
    tar_d = torch.stack([sample[0] for sample in batch], dim=0)
    cand_ds = torch.stack([sample[1] for sample in batch], dim=0)
    pt_label = torch.stack([sample[2] for sample in batch], dim=0)
    gt_idx = torch.tensor([sample[3] for sample in batch], dtype=torch.long)
    # Trim trailing zero padding on the word level and build masks.
    tar_d, _ = del_zeros_right(tar_d)
    tar_mask = compute_mask(tar_d, padding_idx=Vocabulary.PAD_IDX)
    cand_ds, _ = del_zeros_right(cand_ds)
    cand_mask = compute_mask(cand_ds, padding_idx=Vocabulary.PAD_IDX)
    if self.hierarchical:
        # Also trim trailing empty sentences for both sides.
        _, sent_right_idx = del_zeros_right(tar_mask.sum(-1))
        tar_d = tar_d[:, :sent_right_idx, :]
        tar_mask = tar_mask[:, :sent_right_idx, :]
        _, sent_right_idx = del_zeros_right(cand_mask.sum(-1))
        cand_ds = cand_ds[:, :, :sent_right_idx, :]
        cand_mask = cand_mask[:, :, :sent_right_idx, :]
    return tar_d, tar_mask, cand_ds, cand_mask, pt_label, gt_idx
def forward(self, docs_rep, turn_nl): turn_nl, _ = del_zeros_right(turn_nl) # 1. user response encoding turn_emb = self.embedding_layer(turn_nl) turn_mask = compute_mask(turn_nl) # (batch, hidden_size * 2) turn_word_rep, turn_rep = self.turn_rnn(turn_emb, turn_mask) # 1.1 user response classification # (batch, num_slots) turn_slot_cls = self.turn_cls_linear(turn_rep) turn_slot_cls = torch.softmax(turn_slot_cls, dim=-1) # (batch, 1) turn_inform_sig = self.turn_inform_linear(turn_rep) turn_inform_sig = torch.sigmoid(turn_inform_sig) # (batch, num_slots + 1) turn_slot_inform_cls = torch.cat([(1 - turn_inform_sig) * turn_slot_cls, turn_inform_sig], dim=-1) # 2. compute similar with bi-linear layer batch, num_docs, num_slots, _ = docs_rep.size() # (batch, num_docs, num_slots, hidden_size * 2) turn_rep_expand = turn_rep.unsqueeze(1).unsqueeze(1).expand(-1, num_docs, num_slots, -1).contiguous() # (batch, num_docs, num_slots) turn_docs_slots_similar = self.similar_bilinear(docs_rep, turn_rep_expand).squeeze(-1) extra_similar = turn_docs_slots_similar.new_ones(batch, num_docs, 1) # 2.1 similar with slots weighted # (batch, num_docs, num_slots + 1) turn_docs_slots_similar_extra = torch.cat([turn_docs_slots_similar, extra_similar], dim=-1) # (batch * num_docs, num_slots + 1) turn_docs_slots_similar_extra = turn_docs_slots_similar_extra.reshape(batch * num_docs, -1) # (batch, num_docs, num_slots + 1) turn_slot_inform_cls_expand = turn_slot_inform_cls.unsqueeze(1).expand(-1, num_docs, -1) # (batch * num_docs, num_slots + 1) turn_slot_inform_cls_expand = turn_slot_inform_cls_expand.reshape(batch * num_docs, -1) # (batch * num_docs) turn_docs_similar = torch.bmm(turn_docs_slots_similar_extra.unsqueeze(1), turn_slot_inform_cls_expand.unsqueeze(2)).squeeze(1).squeeze(1) # (batch, num_docs) turn_level_docs_dist = torch.softmax(turn_docs_similar.view(batch, num_docs), dim=-1) return turn_level_docs_dist, turn_slot_cls, turn_inform_sig
def forward(self, doc):
    """Encode one document with both encoders and concatenate the results.

    Args:
        doc: document token ids; assumed (batch, sent, word) — TODO confirm.

    Returns:
        Concatenation of the target-encoder and candidate-encoder outputs
        along the last dimension.
    """
    # Trim trailing zero padding on both the word and sentence levels.
    doc, _ = del_zeros_right(doc)
    _, sent_right_idx = del_zeros_right(doc.sum(-1))
    doc = doc[:, :sent_right_idx, :]
    emb = self.embedding_layer(doc)
    mask = compute_mask(doc)
    # Encode the same document once as a "target" and once as a "candidate".
    rep_as_target, _ = self.tar_doc_encoder(emb, mask)
    rep_as_candidate, _ = self.cand_doc_encoder(emb, mask)
    return torch.cat([rep_as_target, rep_as_candidate], dim=-1)
def forward(self, contents, question_ans, logics, questions, answers):
    """Multi-document reasoning forward pass (shared-encoder variant).

    Matches each content against the question, applies cross-document
    attention, gates and reasons over the results, and produces 5-way
    scores per group of 5 rows.

    Args:
        contents: e.g. (16, 10, 200) — (batch, num_contents, content_len).
        question_ans: e.g. (16, 100) — (batch, question_len).
        logics: polarity multiplier; -1 for reversed logic, +1 for forward.
        questions, answers: unused here — TODO confirm callers rely on them.

    Returns:
        Tensor of shape (batch / 5, 5) with polarity-adjusted scores.
    """
    batch_size = question_ans.size()[0]
    max_content_len = contents.size()[2]
    max_question_len = question_ans.size()[1]
    contents_num = contents.size()[1]
    # word-level embedding: fixed (pretrained) plus a trainable delta.
    question_vec, question_mask = self.fixed_embedding.forward(
        question_ans)  # size=(batch, 100, 200)
    delta_emb = self.delta_embedding.forward(
        question_ans)  # size=(batch, 100, 200)
    question_vec += delta_emb  # size=(batch, 100, 200)
    question_encode = self.context_layer.forward(
        question_vec, question_mask)  # size=(batch, 100, 256)
    contents = contents[:, 0:self.
                        use_content_nums, :]  # use top n contents
    # Fold the content dim into the batch dim for a single encoder pass.
    viewed_contents = contents.contiguous().view(
        batch_size * self.use_content_nums,
        max_content_len)  # size=(batch*n, 200)
    viewed_content_vec, viewed_content_mask = self.fixed_embedding.forward(
        viewed_contents)
    viewed_delta_emb = self.delta_embedding.forward(viewed_contents)
    viewed_content_vec += viewed_delta_emb
    viewed_content_encode = self.context_layer.forward(
        viewed_content_vec, viewed_content_mask)  # size=(batch*n, 200, 256)
    content_encode = viewed_content_encode.view(
        batch_size, self.use_content_nums, max_content_len,
        self.hidden_size * 2).transpose(0, 1)  # size=(n, batch, 200, 256)
    # Gating values for the reasoning stage.
    viewed_reasoning_content_gating_val = self.reasoning_gating_layer(
        viewed_content_encode)  # size=(16*n, 200, 1)
    reasoning_question_gating_val = self.reasoning_gating_layer(
        question_encode)  # size=(16, 100, 1)
    # matching features
    matching_feature_row = []  # list[tensor(16, 200, 2)]
    matching_feature_col = []  # list[tensor(16, 100, 2)]
    # compute RnQ (content-aware question rep) & RnD (question-aware
    # content rep) per content.
    RnQ = []  # list[tensor[16, 100, 256]]
    RnD = []  # list[tensor[16, 200, 256]]
    D_RnD = []  # concatenation of each content encoding with its RnD
    for i in range(self.use_content_nums):
        cur_content_encode = content_encode[
            i]  # the i-th content, size=(16, 200, 256)
        cur_Matching_matrix = self.compute_matching_matrix(
            question_encode, cur_content_encode
        )  # (batch, question_len, content_len), e.g. (16, 100, 200)
        cur_an_matrix = torch.nn.functional.softmax(
            cur_Matching_matrix, dim=2
        )  # column-wise softmax: each row of the matching matrix sums to 1
        cur_bn_matrix = torch.nn.functional.softmax(
            cur_Matching_matrix, dim=1
        )  # row-wise attention: each column of the matching matrix sums to 1
        cur_RnQ = self.compute_RnQ(
            cur_an_matrix, cur_content_encode
        )  # size=(batch, 100, 256)
        cur_RnD = self.compute_RnD(cur_bn_matrix,
                                   question_encode)  # size=[16, 200, 256]
        cur_D_RnD = torch.cat([cur_content_encode, cur_RnD],
                              dim=2)  # size=(16, 200, 512)
        D_RnD.append(cur_D_RnD)
        RnQ.append(cur_RnQ)
        RnD.append(cur_RnD)
        # matching features: max/mean pooling of the matching matrix along
        # each axis, stacked as a 2-channel feature.
        cur_max_pooling_feature_row, _ = torch.max(cur_Matching_matrix,
                                                   dim=1)  # size=(16, 200)
        cur_mean_pooling_feature_row = torch.mean(cur_Matching_matrix,
                                                  dim=1)  # size=(16, 200)
        cur_matching_feature_row = torch.stack(
            [cur_max_pooling_feature_row, cur_mean_pooling_feature_row],
            dim=-1)  # size=(16, 200, 2)
        matching_feature_row.append(cur_matching_feature_row)
        cur_max_pooling_feature_col, _ = torch.max(cur_Matching_matrix,
                                                   dim=2)  # size=(16, 100)
        cur_mean_pooling_feature_col = torch.mean(cur_Matching_matrix,
                                                  dim=2)  # size=(16, 100)
        cur_matching_feature_col = torch.stack(
            [cur_max_pooling_feature_col, cur_mean_pooling_feature_col],
            dim=-1)  # size=(16, 100, 2)
        matching_feature_col.append(cur_matching_feature_col)
    # Cross-document attention: aggregate every document's D_RnD into each
    # document's RmD via a softmax over all documents.
    RmD = []  # list[tensor(16, 200, 512)]
    for i in range(self.use_content_nums):
        D_RnD_m = D_RnD[i]  # size=(16, 200, 512)
        Mmn_i = []
        RmD_i = []
        for j in range(self.use_content_nums):
            D_RnD_n = D_RnD[j]  # size=(16, 200, 512)
            # attention between any two documents, size=(16, 200, 200)
            Mmn_i_j = self.compute_cross_document_attention(
                D_RnD_m, D_RnD_n)
            Mmn_i.append(Mmn_i_j)
        Mmn_i = torch.stack(Mmn_i, dim=-1)  # size=(16, 200, 200, 10)
        softmax_Mmn_i = self.reduce_softmax(Mmn_i)  # size=(16, 200, 200, 10)
        for j in range(self.use_content_nums):
            D_RnD_n = D_RnD[j]  # size=(16, 200, 512)
            beta_mn_i_j = softmax_Mmn_i[:, :, :, j]  # size=(16, 200, 200)
            cur_RmD = torch.bmm(beta_mn_i_j, D_RnD_n)  # size=(16, 200, 512)
            RmD_i.append(cur_RmD)
        RmD_i = torch.stack(RmD_i, dim=1)  # size=(16, 10, 200, 512)
        RmD_i = torch.sum(RmD_i, dim=1)  # size=(16, 200, 512)
        RmD.append(RmD_i)
    # Re-fold the content dim into the batch dim for the reasoning layers.
    RnQ = torch.stack(RnQ, dim=1)  # size=(16, n, 100, 256)
    viewed_RnQ = RnQ.view(batch_size * self.use_content_nums,
                          max_question_len,
                          self.hidden_size * 2)  # size=(16*n, 100, 256)
    RmD = torch.stack(RmD, dim=1)
    viewed_RmD = RmD.view(batch_size * self.use_content_nums,
                          max_content_len,
                          self.hidden_size * 4)  # size=(16*n, 200, 512)
    viewed_matching_feature_col = torch.stack(
        matching_feature_col,
        dim=1).view(batch_size * self.use_content_nums, max_question_len,
                    2)  # size=(16*n, 100, 2)
    viewed_matching_feature_row = torch.stack(
        matching_feature_row,
        dim=1).view(batch_size * self.use_content_nums, max_content_len,
                    2)  # size=(16*n, 200, 2)
    viewed_RnQ = torch.cat([viewed_RnQ, viewed_matching_feature_col],
                           dim=-1)  # size=(16*n, 100, 258)
    viewed_RmD = torch.cat([viewed_RmD, viewed_matching_feature_row],
                           dim=-1)  # size=(16*n, 200, 514)
    viewed_RnQ_mask = compute_mask(
        viewed_RnQ.mean(dim=2), PreprocessData.padding_idx)
    viewed_RmD_mask = compute_mask(
        viewed_RmD.mean(dim=2), PreprocessData.padding_idx)
    # Repeat the question gate per content and flatten to match viewed_RnQ.
    viewed_reasoning_question_gating_val = reasoning_question_gating_val.unsqueeze(1) \
        .repeat(1, self.use_content_nums, 1, 1) \
        .view(batch_size * self.use_content_nums, max_question_len, 1)
    gated_cur_RnQ = self.compute_gated_value(
        viewed_RnQ, viewed_reasoning_question_gating_val)
    gated_cur_RmD = self.compute_gated_value(
        viewed_RmD, viewed_reasoning_content_gating_val)
    # reasoning layers
    cur_RnQ_reasoning_out = self.question_reasoning_layer.forward(
        gated_cur_RnQ, viewed_RnQ_mask)  # size=(16*n, 100, 256)
    cur_RmD_reasoning_out = self.content_reasoning_layer.forward(
        gated_cur_RmD, viewed_RmD_mask)  # size=(16*n, 200, 256)
    # Pool away the sequence dimension after reasoning.
    RnQ_reasoning_out_max_pooling, _ = torch.max(cur_RnQ_reasoning_out,
                                                 dim=1)  # size=(16*n, 256)
    RmD_reasoning_out_max_pooling, _ = torch.max(cur_RmD_reasoning_out,
                                                 dim=1)  # size=(16*n, 256)
    viewed_reasoning_feature = torch.cat(
        [RnQ_reasoning_out_max_pooling, RmD_reasoning_out_max_pooling],
        dim=1)  # size=(16*n, 512)
    reasoning_feature = viewed_reasoning_feature.view(
        batch_size, self.use_content_nums,
        self.hidden_size * 4)  # size=(16, 10, 512)
    # Noise gate down-weights unhelpful contents before decision pooling.
    noise_gate_val = self.decision_gating_layer(
        reasoning_feature)  # size=(16, 10, 1)
    gated_reasoning_feature = self.compute_gated_value(
        reasoning_feature, noise_gate_val)  # size=(16, 10, 512)
    reasoning_out_max_feature, _ = torch.max(gated_reasoning_feature,
                                             dim=1)  # size=(16, 512)
    reasoning_out_mean_feature = torch.mean(gated_reasoning_feature,
                                            dim=1)  # size=(16, 512)
    decision_input = torch.cat(
        [reasoning_out_max_feature, reasoning_out_mean_feature],
        dim=1)  # size=(16, 1024)
    # self-attention feature over the question encoding
    qa_attention = self.self_attention_layer.forward(
        question_encode)  # size=(16, 100, 50)
    qa_attention = torch.nn.functional.softmax(qa_attention,
                                               dim=1)  # size=(16, 100, 50)
    qa_match = torch.bmm(question_encode.transpose(1, 2),
                         qa_attention)  # size=(16, 256, 50)
    qa_matching_feature = torch.sum(qa_match, dim=1)  # size=(16, 50)
    decision_input = torch.cat([decision_input, qa_matching_feature],
                               dim=1)
    decision_output = self.decision_layer.forward(
        decision_input)  # size=(16, 1)
    # NOTE(review): resize_() is deprecated and mutates in place; view(-1, 1)
    # would be safer for autograd — confirm before changing.
    logics = logics.resize_(logics.size()[0], 1)
    # logics is -1 for reversed polarity, +1 for forward polarity.
    output = decision_output * logics
    # Group every 5 rows into one 5-way decision.
    output = output.view(int(output.size()[0] / 5), 5)
    return output
def forward(self, contents, question_ans, logics, contents_char=None, question_ans_char=None):
    """Multi-document reasoning forward pass (seq-first encoder variant).

    Encodes question and contents separately, pads everything to common
    lengths, matches question against each content, applies cross-document
    attention, reasons over gated features, and emits a 5-way decision per
    group of 5 rows.

    Args:
        contents: (batch, num_contents, content_len) token ids.
        question_ans: (batch, question_len) token ids.
        logics: polarity values; unused in this variant (see return note).
        contents_char, question_ans_char: optional char-level ids; unused.

    Returns:
        Tensor of shape (batch / 5, 5) from the decision layer.
    """
    # assert contents_char is not None and question_ans_char is not None
    batch_size = question_ans.size()[0]
    max_content_len = contents.size()[2]
    max_question_len = question_ans.size()[1]
    contents_num = contents.size()[1]
    # word-level embedding: (seq_len, batch, embedding_size)
    content_vec = []
    content_mask = []
    question_vec, question_mask = self.embedding.forward(question_ans)
    for i in range(contents_num):
        cur_content = contents[:, i, :]
        cur_content_vec, cur_content_mask = self.embedding.forward(cur_content)
        content_vec.append(cur_content_vec)
        content_mask.append(cur_content_mask)
    question_encode, _ = self.context_layer.forward(question_vec, question_mask)  # size=(cur_batch_max_questionans_len, batch, 256)
    content_encode = []
    # word-level encode: (seq_len, batch, hidden_size)
    for i in range(contents_num):
        cur_content_vec = content_vec[i]
        cur_content_mask = content_mask[i]
        cur_content_encode, _ = self.context_layer.forward(cur_content_vec, cur_content_mask)  # size=(cur_batch_max_content_len, batch, 256)
        content_encode.append(cur_content_encode)
    # Pad all content encodings to length 200 and the question to 100 so
    # every batch has identical shapes.
    same_sized_content_encode = []
    for i in range(contents_num):
        cur_content_encode = content_encode[i]
        cur_content_encode = self.full_matrix_to_specify_size(cur_content_encode,
                                                              [max_content_len, batch_size, cur_content_encode.size()[2]])  # size=(200, 16, 256)
        same_sized_content_encode.append(cur_content_encode)
    same_sized_question_encode = self.full_matrix_to_specify_size(question_encode,
                                                                  [max_question_len, batch_size, question_encode.size()[2]])  # size=(100, 16, 256)
    # Gating values for the reasoning and decision stages.
    reasoning_content_gating_val = []
    reasoning_question_gating_val = None
    decision_content_gating_val = []
    decision_question_gating_val = None
    for i in range(contents_num):
        cur_content_encode = same_sized_content_encode[i]  # size=(200, 16, 256)
        cur_gating_input = cur_content_encode.permute(1, 2, 0)  # size=(16, 256, 200)
        cur_reasoning_content_gating_val = self.reasoning_gating_layer(cur_gating_input)  # size=(16, 1, 200)
        # Add a small epsilon so no gate is exactly zero, which would break
        # pad_sequence later.
        cur_reasoning_content_gating_val = cur_reasoning_content_gating_val + 0.00001
        cur_decision_content_gating_val = self.decision_gating_layer(cur_gating_input)  # size=(16, 1, 200)
        cur_decision_content_gating_val = cur_decision_content_gating_val + 0.00001
        reasoning_content_gating_val.append(cur_reasoning_content_gating_val)
        decision_content_gating_val.append(cur_decision_content_gating_val)
    question_gating_input = same_sized_question_encode.permute(1, 2, 0)  # size=(16, 256, 100)
    reasoning_question_gating_val = self.reasoning_gating_layer(question_gating_input)  # size=(16, 1, 100)
    reasoning_question_gating_val = reasoning_question_gating_val + 0.00001
    decision_question_gating_val = self.decision_gating_layer(question_gating_input)  # size=(16, 1, 100)
    decision_question_gating_val = decision_question_gating_val + 0.00001
    # Matching matrix between the question and every content.
    Matching_matrix = []
    for i in range(contents_num):
        cur_content_encode = same_sized_content_encode[i]
        cur_Matching_matrix = self.compute_matching_matrix(same_sized_question_encode,
                                                           cur_content_encode)  # (batch, question_len, content_len), e.g. (16, 100, 200)
        Matching_matrix.append(cur_Matching_matrix)
    # compute an & bn: row- and column-normalized attention matrices.
    an_matrix = []
    bn_matrix = []
    for i in range(contents_num):
        cur_Matching_matrix = Matching_matrix[i]
        cur_an_matrix = torch.nn.functional.softmax(cur_Matching_matrix, dim=2)  # each row sums to 1
        cur_bn_matrix = torch.nn.functional.softmax(cur_Matching_matrix, dim=1)  # each column sums to 1
        an_matrix.append(cur_an_matrix)
        bn_matrix.append(cur_bn_matrix)
    # compute RnQ (content-aware question rep) & RnD (question-aware content rep)
    RnQ = []  # list[tensor[16, 100, 256]]
    RnD = []
    for i in range(contents_num):
        cur_an_matrix = an_matrix[i]
        cur_content_encode = same_sized_content_encode[i]
        cur_bn_matrix = bn_matrix[i]
        cur_RnQ = self.compute_RnQ(cur_an_matrix, cur_content_encode)  # size=(batch, question_len, 256)
        cur_RnD = self.compute_RnD(cur_bn_matrix, same_sized_question_encode)  # size=(batch, content_len, 256)
        RnQ.append(cur_RnQ)
        RnD.append(cur_RnD)
    ########### compute Mmn' ##############
    D_RnD = []  # concatenation of each content encoding with its RnD
    for i in range(contents_num):
        cur_content_encode = same_sized_content_encode[i].transpose(0, 1)  # size=(16, 200, 256)
        cur_RnD = RnD[i]  # size=(16, 200, 256)
        cur_D_RnD = torch.cat([cur_content_encode, cur_RnD], dim=2)  # size=(16, 200, 512)
        D_RnD.append(cur_D_RnD)
    # Cross-document attention aggregation.
    RmD = []  # list[tensor(16, 200, 512)]
    for i in range(contents_num):
        D_RnD_m = D_RnD[i]  # size=(16, 200, 512)
        Mmn_i = []
        RmD_i = []
        for j in range(contents_num):
            D_RnD_n = D_RnD[j]  # size=(16, 200, 512)
            Mmn_i_j = self.compute_cross_document_attention(D_RnD_m, D_RnD_n)  # attention between two documents, size=(16, 200, 200)
            Mmn_i.append(Mmn_i_j)
        Mmn_i = torch.stack(Mmn_i).permute(1, 2, 3, 0)  # size=(16, 200, 200, 10)
        softmax_Mmn_i = self.reduce_softmax(Mmn_i)  # size=(16, 200, 200, 10)
        for j in range(contents_num):
            D_RnD_n = D_RnD[j]  # size=(16, 200, 512)
            beta_mn_i_j = softmax_Mmn_i[:, :, :, j]
            cur_RmD = torch.bmm(beta_mn_i_j, D_RnD_n)  # size=(16, 200, 512)
            RmD_i.append(cur_RmD)
        RmD_i = torch.stack(RmD_i)  # size=(10, 16, 200, 512)
        RmD_i = RmD_i.transpose(0, 1)  # size=(16, 10, 200, 512)
        RmD_i = torch.sum(RmD_i, dim=1)  # size=(16, 200, 512)
        RmD.append(RmD_i)
    # Max/mean pooled matching features per content.
    matching_feature_row = []  # list[tensor(16, 200, 2)]
    matching_feature_col = []  # list[tensor(16, 100, 2)]
    for i in range(contents_num):
        cur_Matching_matrix = Matching_matrix[i]  # size=(16, 100, 200)
        cur_max_pooling_feature_row, _ = torch.max(cur_Matching_matrix, dim=1)  # size=(16, 200)
        cur_mean_pooling_feature_row = torch.mean(cur_Matching_matrix, dim=1)  # size=(16, 200)
        cur_matching_feature_row = torch.stack([cur_max_pooling_feature_row,
                                                cur_mean_pooling_feature_row]).permute(1, 2, 0)  # size=(16, 200, 2)
        matching_feature_row.append(cur_matching_feature_row)
        cur_max_pooling_feature_col, _ = torch.max(cur_Matching_matrix, dim=2)  # size=(16, 100)
        cur_mean_pooling_feature_col = torch.mean(cur_Matching_matrix, dim=2)  # size=(16, 100)
        cur_matching_feature_col = torch.stack([cur_max_pooling_feature_col,
                                                cur_mean_pooling_feature_col]).permute(1, 2, 0)  # size=(16, 100, 2)
        matching_feature_col.append(cur_matching_feature_col)
    reasoning_feature = []
    RnQ_reasoning_out = []
    RmD_reasoning_out = []
    for i in range(contents_num):
        cur_RnQ = RnQ[i]  # size=(16, 100, 256)
        cur_RmD = RmD[i]  # size=(16, 200, 512)
        cur_matching_feature_col = matching_feature_col[i]  # size=(16, 100, 2)
        cur_matching_feature_row = matching_feature_row[i]  # size=(16, 200, 2)
        cur_RnQ = torch.cat([cur_RnQ, cur_matching_feature_col], dim=2)  # size=(16, 100, 258)
        cur_RmD = torch.cat([cur_RmD, cur_matching_feature_row], dim=2)  # size=(16, 200, 514)
        cur_RnQ_mask = compute_mask(cur_RnQ.mean(dim=2), PreprocessData.padding_idx)
        cur_RmD_mask = compute_mask(cur_RmD.mean(dim=2), PreprocessData.padding_idx)
        gated_cur_RnQ = self.compute_gated_value(cur_RnQ, reasoning_question_gating_val)  # size=(16, 100, 258)
        gated_cur_RmD = self.compute_gated_value(cur_RmD, reasoning_content_gating_val[i])  # size=(16, 200, 514)
        # reasoning layers (seq-first input)
        cur_RnQ_reasoning_out, _ = self.question_reasoning_layer.forward(gated_cur_RnQ.transpose(0, 1), cur_RnQ_mask)  # size=(max_seq_len, 16, 256)
        cur_RmD_reasoning_out, _ = self.content_reasoning_layer.forward(gated_cur_RmD.transpose(0, 1), cur_RmD_mask)  # size=(max_seq_len, 16, 256)
        # Pad all outputs back to the common lengths.
        cur_RnQ_reasoning_out = self.full_matrix_to_specify_size(cur_RnQ_reasoning_out,
                                                                 [max_question_len, batch_size, cur_RnQ_reasoning_out.size()[2]])  # size=(100, 16, 256)
        cur_RmD_reasoning_out = self.full_matrix_to_specify_size(cur_RmD_reasoning_out,
                                                                 [max_content_len, batch_size, cur_RmD_reasoning_out.size()[2]])  # size=(200, 16, 256)
        # back to batch-first
        cur_RnQ_reasoning_out = cur_RnQ_reasoning_out.transpose(0, 1)  # size=(16, 100, 256)
        cur_RmD_reasoning_out = cur_RmD_reasoning_out.transpose(0, 1)  # size=(16, 200, 256)
        RnQ_reasoning_out.append(cur_RnQ_reasoning_out)
        RmD_reasoning_out.append(cur_RmD_reasoning_out)
    # Stack per-content outputs and max-pool over the content dimension.
    RnQ_reasoning_out = torch.stack(RnQ_reasoning_out).transpose(0, 1)  # size=(16, 10, 100, 256)
    RmD_reasoning_out = torch.stack(RmD_reasoning_out).transpose(0, 1)  # size=(16, 10, 200, 256)
    RnQ_reasoning_out_maxpool, _ = torch.max(RnQ_reasoning_out, dim=1)  # size=(16, 100, 256)
    RmD_reasoning_out_maxpool, _ = torch.max(RmD_reasoning_out, dim=1)  # size=(16, 200, 256)
    # Pool over the sequence dimension to build the fully-connected input.
    fc_input_RnQ_maxpool, _ = torch.max(RnQ_reasoning_out_maxpool, dim=1)  # size=(16, 256)
    fc_input_RnQ_meanpool = torch.mean(RnQ_reasoning_out_maxpool, dim=1)  # size=(16, 256)
    fc_input_RmD_maxpool, _ = torch.max(RmD_reasoning_out_maxpool, dim=1)  # size=(16, 256)
    fc_input_RmD_meanpool = torch.mean(RmD_reasoning_out_maxpool, dim=1)  # size=(16, 256)
    fc_input = torch.cat([fc_input_RnQ_maxpool, fc_input_RnQ_meanpool,
                          fc_input_RmD_maxpool, fc_input_RmD_meanpool], dim=1)  # size=(16, 1024)
    # Group every 5 rows into one 5-way decision (1024 * 5 = 5120 features).
    decision_input = fc_input.view(int(batch_size / 5), 5120)
    output = self.decision_layer.forward(decision_input)  # size=(batch_size / 5, 5)
    # NOTE(review): unlike sibling variants, `logics` is never applied here
    # (polarity flip commented out upstream) — confirm this is intentional.
    return output
def forward(self, contents, question_ans, logics, contents_char=None, question_ans_char=None):
    """Multi-document reasoning forward pass (batch-first, fixed+delta embedding variant).

    Encodes question and contents with a fixed embedding plus a trainable
    delta, matches question against each content, applies cross-document
    attention, reasons over gated features, and emits polarity-adjusted
    5-way scores per group of 5 rows.

    Args:
        contents: (batch, num_contents, content_len) token ids.
        question_ans: (batch, question_len) token ids.
        logics: polarity multiplier; -1 for reversed logic, +1 for forward.
        contents_char, question_ans_char: optional char-level ids; unused.

    Returns:
        Tensor of shape (batch / 5, 5) with polarity-adjusted scores.
    """
    # assert contents_char is not None and question_ans_char is not None
    batch_size = question_ans.size()[0]
    max_content_len = contents.size()[2]
    max_question_len = question_ans.size()[1]
    contents_num = contents.size()[1]
    # word-level embedding: fixed (pretrained) plus a trainable delta.
    content_vec = []
    content_mask = []
    question_vec, question_mask = self.fixed_embedding.forward(question_ans)
    delta_emb = self.delta_embedding.forward(question_ans)
    question_vec += delta_emb
    question_encode = self.context_layer.forward(question_vec, question_mask)  # size=(batch, 100, 256)
    content_encode = []
    for i in range(contents_num):
        cur_content = contents[:, i, :]
        cur_content_vec, cur_content_mask = self.fixed_embedding.forward(cur_content)
        delta_emb = self.delta_embedding.forward(cur_content)
        cur_content_vec += delta_emb
        cur_content_encode = self.context_layer.forward(cur_content_vec, cur_content_mask)  # size=(batch, 200, 256)
        content_encode.append(cur_content_encode)
    # Gating values for the reasoning stage.
    reasoning_content_gating_val = []
    reasoning_question_gating_val = None
    for i in range(contents_num):
        cur_content_encode = content_encode[i]  # size=(16, 200, 256)
        cur_reasoning_content_gating_val = self.reasoning_gating_layer(cur_content_encode)  # size=(16, 200, 1)
        # Add a small epsilon so no gate is exactly zero, which would break
        # pad_sequence later.
        cur_reasoning_content_gating_val = cur_reasoning_content_gating_val + (1e-8)
        reasoning_content_gating_val.append(cur_reasoning_content_gating_val)
    reasoning_question_gating_val = self.reasoning_gating_layer(question_encode)  # size=(16, 100, 1)
    reasoning_question_gating_val = reasoning_question_gating_val + (1e-8)
    # matching features
    matching_feature_row = []  # list[tensor(16, 200, 2)]
    matching_feature_col = []  # list[tensor(16, 100, 2)]
    # compute RnQ (content-aware question rep) & RnD (question-aware content rep)
    RnQ = []  # list[tensor[16, 100, 256]]
    RnD = []  # list[tensor[16, 200, 256]]
    D_RnD = []  # concatenation of each content encoding with its RnD
    for i in range(contents_num):
        cur_content_encode = content_encode[i]  # the i-th content, size=(16, 200, 256)
        cur_Matching_matrix = self.compute_matching_matrix(question_encode, cur_content_encode)  # (batch, question_len, content_len), e.g. (16, 100, 200)
        cur_an_matrix = torch.nn.functional.softmax(cur_Matching_matrix, dim=2)  # column-wise softmax: each row sums to 1
        cur_bn_matrix = torch.nn.functional.softmax(cur_Matching_matrix, dim=1)  # row-wise attention: each column sums to 1
        cur_RnQ = self.compute_RnQ(cur_an_matrix, cur_content_encode)  # size=(batch, 100, 256)
        cur_RnD = self.compute_RnD(cur_bn_matrix, question_encode)  # size=[16, 200, 256]
        cur_D_RnD = torch.cat([cur_content_encode, cur_RnD], dim=2)  # size=(16, 200, 512)
        D_RnD.append(cur_D_RnD)
        RnQ.append(cur_RnQ)
        RnD.append(cur_RnD)
        # matching features: max/mean pooling of the matching matrix.
        cur_max_pooling_feature_row, _ = torch.max(cur_Matching_matrix, dim=1)  # size=(16, 200)
        cur_mean_pooling_feature_row = torch.mean(cur_Matching_matrix, dim=1)  # size=(16, 200)
        cur_matching_feature_row = torch.stack([cur_max_pooling_feature_row,
                                                cur_mean_pooling_feature_row], dim=-1)  # size=(16, 200, 2)
        matching_feature_row.append(cur_matching_feature_row)
        cur_max_pooling_feature_col, _ = torch.max(cur_Matching_matrix, dim=2)  # size=(16, 100)
        cur_mean_pooling_feature_col = torch.mean(cur_Matching_matrix, dim=2)  # size=(16, 100)
        cur_matching_feature_col = torch.stack([cur_max_pooling_feature_col,
                                                cur_mean_pooling_feature_col], dim=-1)  # size=(16, 100, 2)
        matching_feature_col.append(cur_matching_feature_col)
    # Cross-document attention aggregation.
    RmD = []  # list[tensor(16, 200, 512)]
    for i in range(contents_num):
        D_RnD_m = D_RnD[i]  # size=(16, 200, 512)
        Mmn_i = []
        RmD_i = []
        for j in range(contents_num):
            D_RnD_n = D_RnD[j]  # size=(16, 200, 512)
            # attention between any two documents, size=(16, 200, 200)
            Mmn_i_j = self.compute_cross_document_attention(D_RnD_m, D_RnD_n)
            Mmn_i.append(Mmn_i_j)
        Mmn_i = torch.stack(Mmn_i, dim=-1)  # size=(16, 200, 200, 10)
        softmax_Mmn_i = self.reduce_softmax(Mmn_i)  # size=(16, 200, 200, 10)
        for j in range(contents_num):
            D_RnD_n = D_RnD[j]  # size=(16, 200, 512)
            beta_mn_i_j = softmax_Mmn_i[:, :, :, j]  # size=(16, 200, 200)
            cur_RmD = torch.bmm(beta_mn_i_j, D_RnD_n)  # size=(16, 200, 512)
            RmD_i.append(cur_RmD)
        RmD_i = torch.stack(RmD_i, dim=1)  # size=(16, 10, 200, 512)
        RmD_i = torch.sum(RmD_i, dim=1)  # size=(16, 200, 512)
        RmD.append(RmD_i)
    reasoning_feature = []
    for i in range(contents_num):
        cur_RnQ = RnQ[i]  # size=(16, 100, 256)
        cur_RmD = RmD[i]  # size=(16, 200, 512)
        cur_matching_feature_col = matching_feature_col[i]  # size=(16, 100, 2)
        cur_matching_feature_row = matching_feature_row[i]  # size=(16, 200, 2)
        cur_RnQ = torch.cat([cur_RnQ, cur_matching_feature_col], dim=2)  # size=(16, 100, 258)
        cur_RmD = torch.cat([cur_RmD, cur_matching_feature_row], dim=2)  # size=(16, 200, 514)
        cur_RnQ_mask = compute_mask(cur_RnQ.mean(dim=2), PreprocessData.padding_idx)
        cur_RmD_mask = compute_mask(cur_RmD.mean(dim=2), PreprocessData.padding_idx)
        gated_cur_RnQ = self.compute_gated_value(cur_RnQ, reasoning_question_gating_val)  # size=(16, 100, 258)
        gated_cur_RmD = self.compute_gated_value(cur_RmD, reasoning_content_gating_val[i])  # size=(16, 200, 514)
        # reasoning layers
        cur_RnQ_reasoning_out = self.question_reasoning_layer.forward(gated_cur_RnQ, cur_RnQ_mask)  # size=(16, 100, 256)
        cur_RmD_reasoning_out = self.content_reasoning_layer.forward(gated_cur_RmD, cur_RmD_mask)  # size=(16, 200, 256)
        # Max-pool away the sequence dimension after reasoning.
        RnQ_reasoning_out_max_pooling, _ = torch.max(cur_RnQ_reasoning_out, dim=1)  # size=(16, 256)
        RmD_reasoning_out_max_pooling, _ = torch.max(cur_RmD_reasoning_out, dim=1)  # size=(16, 256)
        cur_reasoning_feature = torch.cat([RnQ_reasoning_out_max_pooling,
                                           RmD_reasoning_out_max_pooling], dim=1)  # size=(16, 512)
        reasoning_feature.append(cur_reasoning_feature)
    reasoning_feature = torch.stack(reasoning_feature, dim=1)  # size=(16, 10, 512)
    # Noise gate down-weights unhelpful contents before decision pooling.
    noise_gate_val = self.decision_gating_layer(reasoning_feature)  # size=(16, 10, 1)
    gated_reasoning_feature = self.compute_gated_value(reasoning_feature, noise_gate_val)  # size=(16, 10, 512)
    reasoning_out_max_feature, _ = torch.max(gated_reasoning_feature, dim=1)  # size=(16, 512)
    reasoning_out_mean_feature = torch.mean(gated_reasoning_feature, dim=1)  # size=(16, 512)
    decision_input = torch.cat([reasoning_out_max_feature,
                                reasoning_out_mean_feature], dim=1)  # size=(16, 1024)
    decision_output = self.decision_layer.forward(decision_input)  # size=(16, 1)
    # NOTE(review): resize_() is deprecated and mutates in place; view(-1, 1)
    # would be safer for autograd — confirm before changing.
    logics = logics.resize_(logics.size()[0], 1)
    # logics is -1 for reversed polarity, +1 for forward polarity.
    output = decision_output * logics
    # Group every 5 rows into one 5-way decision.
    output = output.view(int(output.size()[0] / 5), 5)
    return output