def get_ck_local(self, hop, story, story_size, domains):
    # Embed the memory with the domain-specific embedding matrix of the given hop,
    # then sum over the MEM_TOKEN_SIZE dimension to get one vector per memory cell.
    embed = _cuda(torch.zeros((story_size + (self.embedding_dim, ))))
    for i, domain in enumerate(domains):
        embed[i] = self.__getattribute__('C_{}_'.format(domain))[hop](
            story.contiguous()[i])
    embed = torch.sum(embed, 2).squeeze(2)
    return embed
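# Illustrative sketch (not part of the model, toy sizes only): how summing over the
# memory-token dimension turns per-token embeddings into a single bag-of-words vector
# per memory cell, as done above and in the encoder below.
#
#   >>> import torch
#   >>> emb = torch.nn.Embedding(100, 8, padding_idx=0)
#   >>> story = torch.randint(1, 100, (2, 5, 4))   # [batch, memory_len, MEM_TOKEN_SIZE]
#   >>> e = emb(story)                             # [2, 5, 4, 8]
#   >>> e = torch.sum(e, 2)                        # [2, 5, 8] -- one vector per memory cell
#   >>> e.shape
#   torch.Size([2, 5, 8])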
def forward(self, input_seqs, input_lengths):
    # input_lengths: the length of each story in this batch.
    # Keeping the MEM_TOKEN_SIZE dimension separate seems preferable: adjacent elements
    # need their similarity computed, and folding MEM_TOKEN_SIZE into the sequence
    # dimension would split each word's semantics apart.
    embeddings = self.embedding(
        input_seqs)  # [batch_size, story_length, MEM_TOKEN_SIZE, hidden_size]
    # Alternative: keep batch_size, flatten the other two dimensions, then embed:
    # embeddings = self.embedding(input_seqs.contiguous().view(input_seqs.size(0), -1).long())
    # embeddings = embeddings.view(input_seqs.size() + (embeddings.size(-1),))  # restore the extra dimension
    embeddings = torch.sum(embeddings, 2)  # [batch_size, story_length, hidden_size]
    # Dropout randomly drops some embedding features (each embedding dimension is one
    # feature) to reduce overfitting.
    embeddings = self.dropout_layer(embeddings)
    hidden_init = _cuda(
        torch.zeros(2 * self.n_layers, input_seqs.size(0),
                    self.hidden_size))  # [2, batch_size, hidden_size], initial hidden state
    if input_lengths:
        embeddings = nn.utils.rnn.pack_padded_sequence(embeddings,
                                                       input_lengths,
                                                       batch_first=True)
    output, hidden = self.gru(embeddings, hidden_init)
    # output: (batch, seq_len, num_directions * hidden_size) once unpacked
    # hidden: (num_layers * num_directions, batch, hidden_size), e.g. [2, 8, 128]
    if input_lengths:
        # Undo the padding introduced by pack_padded_sequence.
        output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)
    # hidden = self.W(torch.cat((hidden[0], hidden[1]), dim=1))
    output = self.W(output)  # [batch_size, story_length, hidden_size]
    return output, hidden
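# Illustrative sketch (toy sizes, not part of the model): the pack/pad round-trip used
# above, so the GRU skips padded positions. Lengths are sorted descending, which older
# PyTorch versions require when enforce_sorted is not available.
#
#   >>> import torch
#   >>> import torch.nn as nn
#   >>> x = torch.randn(2, 5, 8)                   # [batch, max_len, hidden]
#   >>> lengths = [5, 3]
#   >>> packed = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True)
#   >>> gru = nn.GRU(8, 8, batch_first=True, bidirectional=True)
#   >>> out, h = gru(packed)                       # h: [2, batch, hidden]
#   >>> out, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)
#   >>> out.shape                                  # [batch, max_len, 2 * hidden]
#   torch.Size([2, 5, 16])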
def forward(self, input_seqs, input_lengths):
    embedded = self.embedding(input_seqs.contiguous().view(
        input_seqs.size(0), -1).long())
    embedded = embedded.view(input_seqs.size() + (embedded.size(-1), ))
    embedded = torch.sum(embedded, 2).squeeze(2)
    embedded = self.dropout_layer(embedded.transpose(0, 1))
    global_outputs, global_hidden = self.global_gru(embedded, input_lengths)
    local_outputs = []
    mask = _cuda(torch.zeros((len(input_lengths), input_lengths[0])))
    for i, length in enumerate(input_lengths):
        mask[i, :length] = 1
    for domain in self.domains:
        local_rnn = getattr(self, '{}_gru'.format(domain))
        local_output, _ = local_rnn(embedded, input_lengths)
        local_outputs.append(local_output)
    local_outputs, scores = self.mix_attention(
        torch.stack(local_outputs, dim=-1), mask)
    outputs = self.MLP_H(
        torch.cat((F.dropout(local_outputs, self.dropout, self.training),
                   F.dropout(global_outputs, self.dropout, self.training)),
                  dim=-1))
    hidden = self.selfatten(outputs, input_lengths)
    outputs_ = self.W(outputs)
    hidden_ = self.W(hidden)
    label = self.global_classifier(global_outputs)
    return outputs_, hidden_, label, scores
def get_state(self, bsz):
    """Return a zero-initialized hidden state of shape (2, bsz, hidden_size)."""
    return _cuda(torch.zeros(2, bsz, self.hidden_size))
def forward(self, extKnow, story_size, story_lengths, copy_list,
            encode_hidden, target_batches, max_target_length, batch_size,
            use_teacher_forcing, get_decoded_words, global_pointer):
    # Initialize variables for vocab and pointer
    all_decoder_outputs_vocab = _cuda(
        torch.zeros(max_target_length, batch_size, self.num_vocab))
    all_decoder_outputs_ptr = _cuda(
        torch.zeros(max_target_length, batch_size, story_size[1]))
    decoder_input = _cuda(torch.LongTensor([SOS_token] * batch_size))
    memory_mask_for_step = _cuda(torch.ones(story_size[0], story_size[1]))
    decoded_fine, decoded_coarse = [], []
    hidden = self.relu(self.projector(encode_hidden)).unsqueeze(0)

    # Start to generate word-by-word
    for t in range(max_target_length):
        embed_q = self.dropout_layer(self.C(decoder_input))  # b * e
        if len(embed_q.size()) == 1:
            embed_q = embed_q.unsqueeze(0)
        _, hidden = self.sketch_rnn(embed_q.unsqueeze(0), hidden)
        query_vector = hidden[0]

        p_vocab = self.attend_vocab(self.C.weight, hidden.squeeze(0))
        all_decoder_outputs_vocab[t] = p_vocab
        _, topvi = p_vocab.data.topk(1)

        # Query the external knowledge using the hidden state of the sketch RNN.
        prob_soft, prob_logits = extKnow(query_vector, global_pointer)
        all_decoder_outputs_ptr[t] = prob_logits

        if use_teacher_forcing:
            decoder_input = target_batches[:, t]
        else:
            decoder_input = topvi.squeeze()

        if get_decoded_words:
            search_len = min(5, min(story_lengths))
            prob_soft = prob_soft * memory_mask_for_step
            _, toppi = prob_soft.data.topk(search_len)
            temp_f, temp_c = [], []

            for bi in range(batch_size):
                token = topvi[bi].item()  # topvi[:, 0][bi].item()
                temp_c.append(self.lang.index2word[token])

                if '@' in self.lang.index2word[token]:
                    cw = 'UNK'
                    for i in range(search_len):
                        if toppi[:, i][bi] < story_lengths[bi] - 1:
                            cw = copy_list[bi][toppi[:, i][bi].item()]
                            break
                    temp_f.append(cw)

                    if args['record']:
                        memory_mask_for_step[bi, toppi[:, i][bi].item()] = 0
                else:
                    temp_f.append(self.lang.index2word[token])

            decoded_fine.append(temp_f)
            decoded_coarse.append(temp_c)

    return all_decoder_outputs_vocab, all_decoder_outputs_ptr, decoded_fine, decoded_coarse
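# Illustrative sketch (toy values, hypothetical names): how a sketch tag such as
# '@R_cuisine' is replaced by the highest-ranked valid memory entry, and how the used
# position can be masked out (the "record" behaviour) so it is not copied twice.
#
#   >>> import torch
#   >>> copy_row = ['rome', 'italian', 'expensive', '$sentinel$']
#   >>> prob_soft = torch.tensor([0.1, 0.6, 0.2, 0.1])   # pointer distribution over memory
#   >>> mask = torch.ones(4)
#   >>> _, toppi = (prob_soft * mask).topk(3)
#   >>> picked = next(i.item() for i in toppi if i.item() < len(copy_row) - 1)
#   >>> copy_row[picked]
#   'italian'
#   >>> mask[picked] = 0                                  # block re-copying this slot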
def forward(self, context_hidden, context_outputs, context_lengths, context_mask,
            context_entity, context_entity_lengths, context_entity_mask, context_entity_id,
            kb_entity, kb_entity_id, kb_entity_row, kb_entity_lengths, kb_entity_mask,
            entity, entity_lengths, entity_mask, entity_plain, entity_type,
            target_batches, max_target_length, schedule_sampling, get_decoded_words):
    batch_size, entity_set_length = entity.size(0), entity.size(1)
    # context_entity_id = context_entity_id + context_entity_mask.long() * (entity_set_length - 1)
    # kb_entity_id = kb_entity_id + kb_entity_mask.long() * (entity_set_length - 1)

    # Initialize variables for vocab and pointer
    all_decoder_outputs_vocab = _cuda(
        torch.zeros(max_target_length, batch_size, self.num_vocab))
    all_decoder_outputs_ptr = _cuda(
        torch.zeros(max_target_length, batch_size, entity_set_length))
    decoder_input = _cuda(torch.LongTensor([SOS_token] * batch_size))
    memory_mask_for_step = _cuda(torch.ones(batch_size, entity_set_length))
    decoded_fine, decoded_coarse = [], []
    dec_hidden = self.relu(self.projector(context_hidden))

    # Start to generate word-by-word
    for t in range(max_target_length):
        pre_emb = self.dropout_layer(self.embedder(decoder_input))  # b * e
        if len(pre_emb.size()) == 1:
            pre_emb = pre_emb.unsqueeze(0)
        _, dec_hidden = self.gru(pre_emb.unsqueeze(0), dec_hidden)

        # Context-entity distribution
        p_entity_context = _cuda(torch.zeros(batch_size, entity_set_length))
        context_entity_hidden, context_entity_pro = self.context_entity_attention(
            dec_hidden.transpose(0, 1), context_entity,
            mask=context_entity_mask, return_weights=True)
        p_entity_context.scatter_add_(1, context_entity_id,
                                      context_entity_pro.squeeze(1))

        # KB-entity distribution
        p_entity_kb = _cuda(torch.zeros(batch_size, entity_set_length))
        # Row-level
        kb_entity_row_onehot = to_onehot(
            kb_entity_row, mask=kb_entity_mask).transpose(1, 2)  # B x maxR x maxE
        kb_entity_row_hidden = torch.bmm(kb_entity_row_onehot, kb_entity)  # B x maxR x h
        kb_entity_row_sum = kb_entity_row_onehot.sum(
            2, keepdim=True, dtype=torch.float)  # B x maxR x 1
        kb_entity_row_mask = kb_entity_row_sum.squeeze(2).eq(0)
        kb_entity_row_sum = torch.clamp(kb_entity_row_sum, min=1)
        kb_entity_row_hidden = kb_entity_row_hidden / kb_entity_row_sum
        kb_entity_hidden, kb_entity_row_pro = self.kb_entity_attention(
            dec_hidden.transpose(0, 1), kb_entity_row_hidden,
            mask=kb_entity_row_mask, return_weights=True)
        kb_entity_row_pro = torch.bmm(kb_entity_row_pro,
                                      kb_entity_row_onehot).squeeze(1)  # B x maxE
        # Entity-level
        kb_entity_logit = self.kb_entity_attention(
            dec_hidden.transpose(0, 1), kb_entity,
            return_weights_only=True)  # B x 1 x maxE
        kb_entity_logit = kb_entity_logit * kb_entity_row_onehot  # B x maxR x maxE
        # kb_entity_logit.masked_fill_(torch.logical_not(kb_entity_row_onehot.bool()), -1e9)
        kb_entity_logit.masked_fill_(1 - kb_entity_row_onehot.byte(), -1e9)
        # kb_entity_logit = kb_entity_logit - (1 - kb_entity_row_onehot) * 1e10
        kb_entity_pro = F.softmax(kb_entity_logit, dim=2)
        kb_entity_pro = torch.gather(kb_entity_pro, 1,
                                     kb_entity_row.unsqueeze(1)).squeeze(1)
        # kb_entity_pro = kb_entity_pro.sum(1)
        kb_entity_pro = kb_entity_pro * kb_entity_row_pro
        p_entity_kb.scatter_add_(1, kb_entity_id, kb_entity_pro)
        """
        kb_entity_hidden, kb_entity_logit = self.kb_entity_attention(
            dec_hidden.transpose(0, 1), kb_entity,
            mask=kb_entity_mask, return_weights=True)
        kb_entity_logit = kb_entity_logit.squeeze(1)
        p_entity_kb.scatter_add_(1, kb_entity_id, kb_entity_logit)
        """

        # Soft switch between the context and KB distributions
        switch_input = self.switch(dec_hidden.squeeze(0))
        # pro_switch = self.softmax(switch_input)
        # if not get_decoded_words:
        #     pro_switch = nn.functional.gumbel_softmax(switch_input, tau=1.0 - (epoch / 15.0), hard=False)
        # else:
        #     pro_switch = nn.functional.gumbel_softmax(switch_input, tau=1.0 - (epoch / 15.0), hard=True)
        # p_entity = torch.cat((p_entity_context.unsqueeze(2), p_entity_kb.unsqueeze(2)), dim=2)
        # p_entity = torch.bmm(p_entity, pro_switch.unsqueeze(2)).squeeze(2)
        pro_switch = self.sigmoid(switch_input)
        p_entity = (1.0 - pro_switch) * p_entity_context + pro_switch * p_entity_kb

        # Vocabulary distribution
        vocab_attn = self.context_attention(dec_hidden.transpose(0, 1),
                                            context_outputs,
                                            mask=context_mask)
        # entity_hidden = torch.cat((context_entity_hidden, kb_entity_hidden), dim=1)
        # entity_hidden = torch.bmm(pro_switch.unsqueeze(1), entity_hidden)
        entity_hidden = context_entity_hidden.squeeze(1) * (1 - pro_switch) + \
            kb_entity_hidden.squeeze(1) * pro_switch
        # concat_input = torch.cat((dec_hidden.squeeze(0), vocab_attn.squeeze(1)), dim=1)
        concat_input = torch.cat(
            (dec_hidden.squeeze(0), vocab_attn.squeeze(1),
             entity_hidden.squeeze(1)), dim=1)
        concat_output = torch.tanh(self.concat(concat_input))
        # p_vocab = self.attend_vocab(self.embedder.weight, concat_output)
        p_vocab = self.vocab_matrix(concat_output)

        all_decoder_outputs_vocab[t] = p_vocab
        all_decoder_outputs_ptr[t] = p_entity

        use_teacher_forcing = random.random() < schedule_sampling
        if use_teacher_forcing:
            decoder_input = target_batches[:, t]
        else:
            _, topvi = p_vocab.data.topk(1)
            decoder_input = topvi.squeeze()

        if get_decoded_words:
            prob_soft = self.softmax(p_entity)
            search_len = min(5, min(entity_lengths))
            prob_soft = prob_soft * memory_mask_for_step
            _, toppi = prob_soft.data.topk(search_len)
            temp_f, temp_c = [], []

            for bi in range(batch_size):
                token = topvi[bi].item()  # topvi[:, 0][bi].item()
                temp_c.append(self.vocab.index2word[token])

                if '@' in self.vocab.index2word[token]:
                    slot = self.vocab.index2word[token]
                    cw = 'UNK'
                    for i in range(search_len):
                        top_index = toppi[:, i][bi].item()
                        # if top_index < entity_lengths[bi] - 1 and entity_type[bi][top_index] == slot:
                        if top_index < entity_lengths[bi] - 1:
                            cw = entity_plain[bi][toppi[:, i][bi].item()]
                            break
                    temp_f.append(cw)

                    if args['record']:
                        memory_mask_for_step[bi, toppi[:, i][bi].item()] = 0
                else:
                    temp_f.append(self.vocab.index2word[token])

            decoded_fine.append(temp_f)
            decoded_coarse.append(temp_c)

    return all_decoder_outputs_vocab, all_decoder_outputs_ptr, decoded_fine, decoded_coarse
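# Illustrative sketch (toy values, not part of the model): the soft switch above is a
# per-example sigmoid gate that interpolates between the context-entity and KB-entity
# pointer distributions.
#
#   >>> import torch
#   >>> p_ctx = torch.tensor([[0.7, 0.3, 0.0]])
#   >>> p_kb = torch.tensor([[0.1, 0.1, 0.8]])
#   >>> gate = torch.sigmoid(torch.tensor([[1.5]]))   # close to 1 -> prefer the KB
#   >>> (1 - gate) * p_ctx + gate * p_kb
#   tensor([[0.2095, 0.1365, 0.6541]])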
def forward(self, kb_ent, extKnow, context, context_mask, copy_list,
            encode_hidden, target_batches, max_target_length,
            schedule_sampling, get_decoded_words):
    batch_size = len(copy_list)
    story_size = max([len(seq) for seq in copy_list])
    extKnow_mask, _ = mask_and_length(kb_ent, PAD_token)

    # Initialize variables for vocab and pointer
    all_decoder_outputs_vocab = _cuda(
        torch.zeros(max_target_length, batch_size, self.num_vocab))
    all_decoder_outputs_ptr = _cuda(
        torch.zeros(max_target_length, batch_size, story_size))
    decoder_input = _cuda(torch.LongTensor([SOS_token] * batch_size))
    decoded_fine, decoded_coarse = [], []
    # hidden = self.relu(self.projector(encode_hidden)).unsqueeze(0)
    hidden = self.tanh(self.projector(encode_hidden)).unsqueeze(0)

    # Start to generate word-by-word
    for t in range(max_target_length + 1):
        rnn_input_list, concat_input_list = [], []
        embed_q = self.dropout_layer(self.embedder(decoder_input))  # b * e
        if len(embed_q.size()) == 1:
            embed_q = embed_q.unsqueeze(0)
        rnn_input_list.append(embed_q)
        rnn_input = torch.cat(rnn_input_list, dim=1)
        _, hidden = self.gru(rnn_input.unsqueeze(0), hidden)
        concat_input_list.append(hidden.squeeze(0))

        # Knowledge attention
        knowledge_outputs = self.knowledge_attention(
            hidden.transpose(0, 1), extKnow, mask=extKnow_mask)
        concat_input_list.append(knowledge_outputs.squeeze(1))

        # Context attention
        context_outputs = self.context_attention(
            hidden.transpose(0, 1), context, mask=context_mask)
        concat_input_list.append(context_outputs.squeeze(1))

        # concat_input = torch.cat((hidden.squeeze(0), context_outputs.squeeze(1), knowledge_outputs.squeeze(1)), dim=1)
        concat_input = torch.cat(concat_input_list, dim=1)
        concat_output = torch.tanh(self.concat(concat_input))

        if t < max_target_length:
            # p_vocab = self.attend_vocab(self.C.weight, concat_output)
            p_vocab = self.vocab_matrix(concat_output)
            all_decoder_outputs_vocab[t] = p_vocab

        if t > 0:
            p_entity = self.entity_ranking(concat_output.unsqueeze(1),
                                           extKnow,
                                           mask=extKnow_mask).squeeze(1)
            all_decoder_outputs_ptr[t - 1] = p_entity

        if t < max_target_length:
            use_teacher_forcing = random.random() < schedule_sampling
            if use_teacher_forcing:
                decoder_input = target_batches[:, t]
            else:
                _, topvi = p_vocab.data.topk(1)
                decoder_input = topvi.squeeze()

    # Turn the stored distributions into words
    if get_decoded_words:
        for t in range(max_target_length):
            p_vocab = all_decoder_outputs_vocab[t]
            p_entity = all_decoder_outputs_ptr[t]
            _, topvi = p_vocab.data.topk(1)
            search_len = min(5, story_size)
            _, toppi = p_entity.data.topk(search_len)
            temp_f, temp_c = [], []

            for bi in range(batch_size):
                token = topvi[bi].item()  # topvi[:, 0][bi].item()
                temp_c.append(self.lang.index2word[token])

                if '@' in self.lang.index2word[token]:
                    cw = 'UNK'
                    for i in range(search_len):
                        # if toppi[:, i][bi] < story_lengths[bi] - 1:
                        if toppi[:, i][bi] > 0 and toppi[:, i][bi] < len(copy_list[bi]):
                            cw = copy_list[bi][toppi[:, i][bi].item()]
                            break
                    temp_f.append(cw)
                else:
                    temp_f.append(self.lang.index2word[token])

            decoded_fine.append(temp_f)
            decoded_coarse.append(temp_c)

    return all_decoder_outputs_vocab, all_decoder_outputs_ptr, decoded_fine, decoded_coarse
def flipByLength(self, input, lengths):
    """Reverse each sequence along the time dimension, flipping only the valid
    (unpadded) prefix and leaving padded positions as zeros."""
    output = _cuda(torch.zeros_like(input))
    for i, l in enumerate(lengths):
        output[i, :l, :] = torch.flip(input[i, :l, :], (0, ))
    return output
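# Illustrative sketch (toy values): flipByLength reverses only the valid prefix of each
# padded sequence, so the padding stays in place at the end.
#
#   >>> import torch
#   >>> x = torch.arange(1, 7).view(1, 6, 1).float()   # one sequence of length 6
#   >>> out = torch.zeros_like(x)
#   >>> l = 4                                          # valid length; last 2 steps are padding
#   >>> out[0, :l, :] = torch.flip(x[0, :l, :], (0,))
#   >>> out.squeeze(-1)
#   tensor([[4., 3., 2., 1., 0., 0.]])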
def forward(self, extKnow, story_size, story_lengths, copy_list,
            encode_hidden, target_batches, max_target_length, batch_size,
            use_teacher_forcing, get_decoded_words, global_pointer,
            H=None, global_entity_type=None, domains=None):
    # Initialize variables for vocab and pointer
    all_decoder_outputs_vocab = _cuda(
        torch.zeros(max_target_length, batch_size, self.num_vocab))
    all_decoder_outputs_ptr = _cuda(
        torch.zeros(max_target_length, batch_size, story_size[1]))
    decoder_input = _cuda(self.domain_emb(domains.view(-1, ))) + self.C(
        _cuda(torch.LongTensor([SOS_token] * batch_size)))
    memory_mask_for_step = _cuda(torch.ones(story_size[0], story_size[1]))
    decoded_fine, decoded_coarse = [], []
    hidden = self.relu(self.projector(encode_hidden)).unsqueeze(0)
    hidden_locals = []
    for i in range(len(self.domains)):
        hidden_locals.append(hidden.clone())
    mask = _cuda(torch.ones((len(story_lengths), 1)))
    global_hiddens = []
    local_hiddens = []
    scores = []

    # Start to generate word-by-word
    for t in range(max_target_length):
        if t != 0:
            decoder_input = self.C(decoder_input)
        embed_q = self.dropout_layer(decoder_input)
        if len(embed_q.size()) == 2:
            embed_q = embed_q.unsqueeze(0)
        _, hidden = self.sketch_rnn_global(embed_q, hidden)
        hidden_locals_ = []
        for domain in self.domains.values():
            hidden_locals_.append(
                self.sketch_rnn_local[domain](embed_q, hidden_locals[domain])[1])
        hidden_locals = hidden_locals_
        hidden_local, score = self.mix_attention(
            torch.stack(hidden_locals, dim=-1).transpose(0, 1), mask)
        hidden_local, score = hidden_local.transpose(0, 1), score.transpose(0, 1)
        scores.append(score)
        query_vector = self.MLP(
            torch.cat((F.dropout(hidden, self.dropout, self.training),
                       F.dropout(hidden_local, self.dropout, self.training)),
                      dim=-1))
        global_hiddens.append(hidden)
        local_hiddens.append(hidden_local)

        p_vocab, context = self.get_p_vocab(query_vector[0], H)
        all_decoder_outputs_vocab[t] = p_vocab
        _, topvi = p_vocab.data.topk(1)

        # Query the external knowledge using the hidden state of the sketch RNN.
        prob_soft, prob_logits = extKnow(context[0], global_pointer)
        all_decoder_outputs_ptr[t] = prob_logits

        if use_teacher_forcing:
            decoder_input = target_batches[:, t]
        else:
            decoder_input = topvi.squeeze()

        if get_decoded_words:
            search_len = min(5, min(story_lengths))
            prob_soft = prob_soft * memory_mask_for_step
            _, toppi = prob_soft.data.topk(search_len)
            temp_f, temp_c = [], []

            for bi in range(batch_size):
                token = topvi[bi].item()
                temp_c.append(self.lang.index2word[token])

                if '@' in self.lang.index2word[token]:
                    gold_type = self.lang.index2word[token]
                    cw = 'UNK'
                    for i in range(search_len):
                        if toppi[:, i][bi] < story_lengths[bi] - 1:
                            cw = copy_list[bi][toppi[:, i][bi].item()]
                            break
                    temp_f.append(cw)

                    if args['record']:
                        memory_mask_for_step[bi, toppi[:, i][bi].item()] = 0
                else:
                    temp_f.append(self.lang.index2word[token])

            decoded_fine.append(temp_f)
            decoded_coarse.append(temp_c)

    label = self.global_classifier(
        torch.cat(global_hiddens, dim=0).transpose(0, 1))
    scores = torch.cat(scores, dim=0).transpose(0, 1).contiguous()
    return all_decoder_outputs_vocab, all_decoder_outputs_ptr, decoded_fine, decoded_coarse, label, scores
def forward(self, story_size, ext_know, global_ptr, story_length,
            max_target_length, batch_size, encoded_hidden, evaluating,
            copy_list, use_teacher_forcing, response_target):
    record = _cuda(torch.ones(story_size[0], story_size[1]))  # [batch_size, story_length]
    # all_decoder_output_ptr holds the local pointer, i.e. a distribution over the
    # memory of the current dialogue.
    all_decoder_output_ptr = _cuda(
        torch.zeros(max_target_length, batch_size, story_size[1]))
    # all_decoder_output_vocab is a distribution over the vocabulary.
    all_decoder_output_vocab = _cuda(
        torch.zeros(max_target_length, batch_size, self.num_vocab))
    # One word is generated per step for every sample in the batch.
    decoder_input = _cuda(torch.LongTensor([SOS_token] * batch_size))
    # Project the concatenated encoder state down to the hidden size; the ReLU adds a
    # non-linearity on top of the projection.
    hidden = self.relu(self.projector(encoded_hidden)).unsqueeze(0)
    decoded_fine, decoded_coarse = [], []

    # Generate the output word-by-word with the sketch RNN.
    for t in range(max_target_length):
        sketch_response = self.dropout_layer(self.C(decoder_input))  # [8] -> [1, 8, 128]
        if len(sketch_response.size()) == 1:
            # When batch_size == 1 the embedding can collapse to a single dimension.
            sketch_response = sketch_response.unsqueeze(0)
        # seq_len is 1 here because one token is decoded per step; the hidden state is
        # carried across decoding steps.
        _, hidden = self.sketch_rnn(
            sketch_response.unsqueeze(0), hidden)  # input: [seq_len, batch_size, embedding_dim]
        query = hidden[0]  # hidden: [num_layers * num_directions, batch, embedding_dim]

        # p_vocab: [batch_size, vocab_size]
        # The paper applies a softmax to p_vocab, but it is left out here because adding
        # it made results noticeably worse.
        # C has shape [vocab_size, embedding_dim]; the (unnormalised) scores are computed
        # against the whole word-embedding matrix, since the embedding layer holds a
        # representation of every vocabulary word.
        p_vocab = hidden.squeeze(0).matmul(self.C.weight.transpose(1, 0))
        # p_vocab = self.attend_vocab(self.C.weight, hidden.squeeze(0))
        all_decoder_output_vocab[t] = p_vocab
        _, top_p_vocab = p_vocab.data.topk(1)

        # Query the external knowledge with the sketch RNN hidden state to obtain the
        # attention distribution over the whole memory (the local pointer), from which
        # words are copied.
        local_ptr, prob_soft = ext_know(query, global_ptr)
        all_decoder_output_ptr[t] = local_ptr

        if use_teacher_forcing:
            # Teacher forcing feeds the gold token back in; this is standard practice at
            # training time, not data leakage, since it is not used during evaluation.
            decoder_input = response_target[:, t]
        else:
            # Feed the prediction back so sketch_response keeps changing; this was where
            # the earlier bug was.
            decoder_input = top_p_vocab.squeeze()

        if evaluating:
            search_len = min(5, min(story_length))
            prob_soft = prob_soft * record
            _, top_p_soft = prob_soft.data.topk(search_len)
            tmp_f, tmp_c = [], []

            for bi in range(batch_size):
                token = top_p_vocab[bi].item()
                tmp_c.append(self.word_map.index2word[token])

                if '@' in self.word_map.index2word[token]:  # '@R_cuisine', '@R_location', '@R_number', '@R_price'
                    cw = 'UNK'  # fallback, replaced by a copied entity if one is found
                    for i in range(search_len):
                        # top_p_soft[i][bi] -> top_p_soft[:, i][bi]
                        if top_p_soft[bi][i] < story_length[bi] - 1:
                            cw = copy_list[bi][top_p_soft[bi][i].item()]
                            break
                    tmp_f.append(cw)  # appended outside the inner loop
                    if args['record']:
                        record[bi][top_p_soft[bi][i].item()] = 0  # zero out memory slots already copied
                else:
                    # Not one of the '@' sketch tags: keep the word itself.
                    tmp_f.append(self.word_map.index2word[token])

            decoded_fine.append(tmp_f)
            decoded_coarse.append(tmp_c)

    return all_decoder_output_vocab, all_decoder_output_ptr, decoded_fine, decoded_coarse
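# Illustrative sketch (toy sizes, not part of the model): the unnormalised vocabulary
# scores above are a matrix product between the decoder state and the shared embedding
# matrix C, i.e. each word is scored by its dot product with the current hidden state.
#
#   >>> import torch
#   >>> C = torch.nn.Embedding(50, 16)                 # [vocab_size, embedding_dim]
#   >>> hidden = torch.randn(1, 3, 16)                 # [1, batch, embedding_dim]
#   >>> p_vocab = hidden.squeeze(0).matmul(C.weight.transpose(1, 0))
#   >>> p_vocab.shape                                  # [batch, vocab_size]
#   torch.Size([3, 50])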