def decode(self, dec_state):
    """Beam-search decode starting from ``dec_state``.

    Runs beam search with beam width ``self.k`` for ``self.max_length``
    steps, storing per-step predecessors/symbols/scores, then backtracks
    via ``self._backtrack`` and returns only the best beam per batch
    element.

    Args:
        dec_state: initial decoder state (project type) supporting
            ``get_batch_size()``, ``inflate(k)`` and ``index_select``.

    Returns:
        Tuple ``(predicts, lengths, scores)``; ``predicts`` is padded
        with ``self.PAD`` beyond each hypothesis length.
    """
    long_tensor_type = torch.cuda.LongTensor if self.use_gpu else torch.LongTensor
    b = dec_state.get_batch_size()

    # Per-batch beam offsets: [[0], [k*1], [k*2], ..., [k*(b-1)]]
    self.pos_index = (long_tensor_type(range(b)) * self.k).view(-1, 1)

    # Inflate the initial hidden states to be of size: (b*k, H)
    dec_state = dec_state.inflate(self.k)

    # Initialize the scores; for the first step,
    # ignore the inflated copies to avoid duplicate entries in the top k
    sequence_scores = long_tensor_type(b * self.k).float()
    sequence_scores.fill_(-float('inf'))
    sequence_scores.index_fill_(
        0, long_tensor_type([i * self.k for i in range(b)]), 0.0)

    # Initialize the input vector
    input_var = long_tensor_type([self.BOS] * b * self.k)

    # Store decisions for backtracking
    stored_scores = list()
    stored_predecessors = list()
    stored_emitted_symbols = list()

    for t in range(1, self.max_length + 1):
        # Run the RNN one step forward
        output, dec_state, attn = self.model.decode(input_var, dec_state)
        log_softmax_output = output.squeeze(1)

        # To get the full sequence scores for the new candidates, add the
        # local scores for t_i to the predecessor scores for t_(i-1)
        sequence_scores = sequence_scores.unsqueeze(1).repeat(1, self.V)
        if self.length_average and t > 1:
            sequence_scores = sequence_scores * \
                (1 - 1 / t) + log_softmax_output / t
        else:
            sequence_scores += log_softmax_output
        scores, candidates = sequence_scores.view(b, -1).topk(self.k, dim=1)

        # Reshape input = (b*k, 1) and sequence_scores = (b*k)
        input_var = (candidates % self.V)
        sequence_scores = scores.view(b * self.k)
        input_var = input_var.view(b * self.k)

        # Update fields for next timestep.
        # FIX: integer floor division replaces the torch-version branch —
        # `/` on integer tensors changed semantics across releases, while
        # `//` yields the predecessor beam row on every version and stays
        # an integer tensor (no .long() needed).
        predecessors = (candidates // self.V +
                        self.pos_index.expand_as(candidates)).view(b * self.k)
        dec_state = dec_state.index_select(predecessors)

        # Update sequence scores and erase scores for the end-of-sentence
        # symbol so that they aren't expanded
        stored_scores.append(sequence_scores.clone())
        eos_indices = input_var.data.eq(self.EOS)
        # FIX: `.nonzero(as_tuple=False).dim() > 0` was always true
        # (nonzero returns a 2-D tensor even when empty); `.any()` is the
        # intended "any EOS emitted?" guard.
        if eos_indices.any():
            sequence_scores.data.masked_fill_(eos_indices, -float('inf'))
        if self.ignore_unk:
            # Erase scores for UNK symbol so that they aren't expanded
            unk_indices = input_var.data.eq(self.UNK)
            if unk_indices.any():
                sequence_scores.data.masked_fill_(unk_indices, -float('inf'))

        # Cache results for backtracking
        stored_predecessors.append(predecessors)
        stored_emitted_symbols.append(input_var)

    predicts, scores, lengths = self._backtrack(stored_predecessors,
                                                stored_emitted_symbols,
                                                stored_scores, b)

    # Keep only the top-1 beam for each batch element
    predicts = predicts[:, :1]
    scores = scores[:, :1]
    lengths = long_tensor_type(lengths)[:, :1]
    mask = sequence_mask(lengths, max_len=self.max_length).eq(0)
    predicts[mask] = self.PAD
    return predicts, lengths, scores
def encode(self, enc_inputs, hidden=None):
    """Encode one dialogue turn and build the decoder's initial state.

    Runs the turn through the encoder, appends the outputs to the
    persistent dialog memories (state/history/index/mask), then fuses
    situation, user-profile and KB readouts into the encoder hidden
    state before constructing ``dec_init_state``.

    Args:
        enc_inputs: tuple ``(inputs, lengths)`` of token ids and lengths.
        hidden: optional initial encoder hidden state.

    Returns:
        Tuple ``(outputs, dec_init_state)`` where ``outputs`` is an
        (empty) ``Pack`` and ``dec_init_state`` is the decoder state.
    """
    outputs = Pack()
    enc_outputs, enc_hidden = self.encoder(enc_inputs, hidden)
    inputs, lengths = enc_inputs
    batch_size = enc_outputs.size(0)
    max_len = enc_outputs.size(1)
    # True at padding positions (positions attention must ignore)
    attn_mask = sequence_mask(lengths, max_len).eq(0)
    if self.with_bridge:
        enc_hidden = self.bridge(enc_hidden)

    # insert dialog memory
    if self.dialog_state_memory is None:
        # First turn of a dialogue: all memories must be unset together.
        assert self.dialog_history_memory is None
        assert self.history_index is None
        assert self.memory_masks is None
        self.dialog_state_memory = enc_outputs
        self.dialog_history_memory = enc_outputs
        self.history_index = inputs
        self.memory_masks = attn_mask
    else:
        # Later turns: the batch may have shrunk (finished dialogues are
        # dropped), so slice every stored memory to the current batch
        # size before concatenating this turn along the time axis.
        batch_state_memory = self.dialog_state_memory[:batch_size, :, :]
        self.dialog_state_memory = torch.cat(
            [batch_state_memory, enc_outputs], dim=1)
        batch_history_memory = self.dialog_history_memory[: batch_size, :, :]
        self.dialog_history_memory = torch.cat(
            [batch_history_memory, enc_outputs], dim=1)
        batch_history_index = self.history_index[:batch_size, :]
        self.history_index = torch.cat([batch_history_index, inputs], dim=-1)
        batch_memory_masks = self.memory_masks[:batch_size, :]
        self.memory_masks = torch.cat([batch_memory_masks, attn_mask], dim=-1)

    # Slice the KB/situation/profile tensors loaded earlier to the
    # current batch size as well.
    batch_kb_inputs = self.kbs[:batch_size, :, :]
    batch_kb_state_memory = self.kb_state_memory[:batch_size, :, :]
    batch_kb_slot_memory = self.kb_slot_memory[:batch_size, :, :]
    batch_kb_slot_index = self.kb_slot_index[:batch_size, :]
    kb_mask = self.kb_mask[:batch_size, :]
    selector_mask = self.selector_mask[:batch_size, :]
    # NOTE(review): situation is sliced on dim 1, unlike the others —
    # presumably stored as (num_layers, batch, hidden); confirm in loader.
    batch_situation = self.situation[:, :batch_size, :]
    batch_user_profile = self.user_profile[:batch_size, :, :]
    batch_user_profile_mask = self.user_profile_mask[:batch_size, :]

    # Fuse situation, then user-profile readout, then KB readout into
    # the hidden state — each bridge consumes the previous result.
    enc_hidden = self.situation_bridge(
        torch.cat([enc_hidden, batch_situation], dim=-1))
    up_memory, up_readout = self.decoder.initialize_user_profile(
        batch_user_profile, enc_hidden, batch_user_profile_mask)
    enc_hidden = self.user_profile_bridge(
        torch.cat([enc_hidden, up_readout.unsqueeze(0)], dim=-1))
    kb_memory, selector, kb_readout = self.decoder.initialize_kb_v3(
        batch_kb_inputs, enc_hidden, kb_mask)
    enc_hidden = self.kb_readout_bridge(
        torch.cat([enc_hidden, kb_readout.unsqueeze(0)], dim=-1))

    dec_init_state = self.decoder.initialize_state(
        hidden=enc_hidden,
        state_memory=self.dialog_state_memory,
        history_memory=self.dialog_history_memory,
        kb_memory=kb_memory,
        kb_state_memory=batch_kb_state_memory,
        kb_slot_memory=batch_kb_slot_memory,
        history_index=self.history_index,
        kb_slot_index=batch_kb_slot_index,
        attn_mask=self.memory_masks,
        attn_kb_mask=kb_mask,
        selector=selector,
        selector_mask=selector_mask,
        up_readout=up_readout)
    return outputs, dec_init_state
def iterate(self, turn_inputs, kb_inputs, situation_inputs=None,
            user_profile_inputs=None, optimizer=None, grad_clip=None,
            use_rl=False, is_training=True):
    """Run one training/eval pass over a batch of multi-turn dialogues.

    Resets and loads the KB/situation/user-profile memories, then
    iterates the turns: encodes/decodes each turn, accumulates the loss
    and carries the updated dialog/KB state memories into the next turn.
    When ``use_rl`` is set, also draws sampled and greedy outputs for
    RL-style metrics.

    Args:
        turn_inputs: iterable of per-turn input packs.
        kb_inputs: knowledge-base inputs for the batch.
        situation_inputs: optional situation inputs.
        user_profile_inputs: optional user-profile inputs.
        optimizer: required when ``is_training`` is True.
        grad_clip: optional max gradient norm.
        use_rl: compute RL metrics instead of supervised metrics.
        is_training: run backward + optimizer step.

    Returns:
        List of per-turn metrics ``Pack``s.

    Raises:
        ValueError: if the accumulated loss is NaN.
    """
    self.reset_memory()
    self.load_kb_memory(kb_inputs)
    self.load_situation_memory(situation_inputs)
    self.load_user_profile_memory(kb_inputs, user_profile_inputs)

    metrics_list = []
    total_loss = 0

    for i, inputs in enumerate(turn_inputs):
        if self.use_gpu:
            inputs = inputs.cuda()
        src, src_lengths = inputs.src
        tgt, tgt_lengths = inputs.tgt
        task_label = inputs.task
        gold_entity = inputs.gold_entity
        # This variant does not use pointer supervision.
        ptr_index, ptr_lengths = None, None
        kb_index, kb_index_lengths = inputs.kb_index
        enc_inputs = src[:, 1:-1], src_lengths - 2  # filter <bos> <eos>
        dec_inputs = tgt[:, :-1], tgt_lengths - 1  # filter <eos>
        target = tgt[:, 1:]  # filter <bos>
        target_mask = sequence_mask(tgt_lengths - 1)

        if use_rl:
            # Self-critical style: sampled vs. greedy rollouts feed the
            # RL metric; greedy rollout needs no gradients.
            sample_outputs = self.sample(enc_inputs, dec_inputs,
                                         random_sample=True)
            with torch.no_grad():
                greedy_outputs = self.sample(enc_inputs, dec_inputs,
                                             random_sample=False)
            outputs = self.forward(enc_inputs, dec_inputs)
            metrics = self.collect_rl_metrics(sample_outputs, greedy_outputs,
                                              target, gold_entity, ptr_index,
                                              kb_index, target_mask,
                                              task_label)
        else:
            outputs = self.forward(enc_inputs, dec_inputs)
            metrics = self.collect_metrics(outputs, target, ptr_index,
                                           kb_index)

        metrics_list.append(metrics)
        total_loss += metrics.loss
        # Carry the per-turn memories forward so the next turn attends
        # over the updated dialogue/KB state.
        self.update_memory(dialog_state_memory=outputs.dialog_state_memory,
                           kb_state_memory=outputs.kb_state_memory)

    if torch.isnan(total_loss):
        raise ValueError("NAN loss encountered!")

    if is_training:
        assert optimizer is not None
        optimizer.zero_grad()
        total_loss.backward()
        if grad_clip is not None and grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(parameters=self.parameters(),
                                           max_norm=grad_clip)
        optimizer.step()
    return metrics_list
def iterate(self, turn_inputs, kb_inputs, optimizer=None, grad_clip=None,
            is_training=True, method="GAN", mask=False):
    """Run one adversarial training pass over the whole multi-agent model.

    Note: this iterates the whole model (multi-agent) instead of a
    single sub-model. Phase 1 updates the two discriminators (B and E)
    against student (model_S) vs. teacher (model_TB/model_TE) output
    distributions; phase 2 updates the student generator against both
    discriminators plus its NLL loss.

    Args:
        turn_inputs: iterable of per-turn input packs.
        kb_inputs: knowledge-base inputs for the batch.
        optimizer: tuple ``(optimizerG, optimizerDB, optimizerDE)``.
        grad_clip: optional max gradient norm.
        is_training: run backward + optimizer steps.
        method: comparison strategy; '1-3'/'GAN' skip the per-batch
            bleu/f1 comparison.
        mask: use entity-aware kd masks instead of plain target masks.

    Returns:
        ``(metrics_list_S, metrics_list_G, metrics_list_DB,
        metrics_list_DE)`` — per-turn metrics for student, generator,
        and both discriminators.

    Raises:
        ValueError: if any accumulated loss is NaN.
    """
    if isinstance(optimizer, tuple):
        optimizerG, optimizerDB, optimizerDE = optimizer

    # clear all memory before the begin of a new batch computation
    for name, model in self.named_children():
        if name.startswith("model_"):
            model.reset_memory()
            model.load_kb_memory(kb_inputs)

    # store the whole model (multi-agent)'s metrics
    metrics_list_S, metrics_list_TB, metrics_list_TE = [], [], []
    metrics_list_G, metrics_list_DB, metrics_list_DE = [], [], []
    mask_list_S, length_list = [], []
    # store the whole model (multi-agent)'s loss
    total_loss_DB, total_loss_DE, total_loss_G = 0, 0, 0
    # used to collect each agent's per-turn metrics for the cumulated
    # total_loss in a batch
    loss = Pack()
    # used to store the kd mask for each of the three single models
    kd_masks = Pack()

    # compare evaluation metric (bleu/f1score) among models
    if method in ('1-3', 'GAN'):
        # TODO complete
        bleu_ENS_gt_S, bleu_ENS_gt_TB, f1score_ENS_gt_TE = True, True, True
    else:
        # compute bleu_S_gt_TB per batch (compute metric for the following
        # training batch) (key: batch/following/training)
        res_bleu = self.compare_metric(generator_1=self.generator_S,
                                       generator_2=self.generator_TB,
                                       turn_inputs=turn_inputs,
                                       kb_inputs=kb_inputs,
                                       type='bleu',
                                       data_name=self.data_name)
        if isinstance(res_bleu, tuple):
            bleu_S_gt_TB, bleu_S_gt_TB_str = res_bleu
        else:
            assert isinstance(res_bleu, bool)
            bleu_S_gt_TB, bleu_S_gt_TB_str = res_bleu, ''
        if self.model_TE is not None:
            res_f1score = self.compare_metric(
                generator_1=self.generator_S,
                generator_2=self.generator_TE,
                turn_inputs=turn_inputs,
                kb_inputs=kb_inputs,
                type='f1score',
                data_name=self.data_name)
            if isinstance(res_f1score, tuple):
                f1score_S_gt_TE, f1score_S_gt_TE_str = res_f1score
            else:
                assert isinstance(res_f1score, bool)
                f1score_S_gt_TE, f1score_S_gt_TE_str = res_f1score, ''

    """ update discriminator """
    # clear all memory again because of the accumulation of memory in the
    # computation of the above generator comparison
    for name, model in self.named_children():
        if name.startswith("model_"):
            model.reset_memory()
            model.load_kb_memory(kb_inputs)

    # begin iterate (a dialogue batch)
    for i, inputs in enumerate(turn_inputs):
        # Run every sub-model (student + teachers) on the same turn.
        for name, model in self.named_children():
            if name.startswith("model_"):
                if model.use_gpu:
                    inputs = inputs.cuda()
                src, src_lengths = inputs.src
                tgt, tgt_lengths = inputs.tgt
                task_label = inputs.task
                gold_entity = inputs.gold_entity
                ptr_index, ptr_lengths = inputs.ptr_index
                kb_index, kb_index_lengths = inputs.kb_index
                enc_inputs = src[:, 1: -1], src_lengths - 2  # filter <bos> <eos>
                dec_inputs = tgt[:, :-1], tgt_lengths - 1  # filter <eos>
                target = tgt[:, 1:]  # filter <bos>
                target_mask = sequence_mask(tgt_lengths - 1)
                # Entity-aware mask distinguishing entity/non-entity tokens.
                kd_mask = sequence_kd_mask(tgt_lengths - 1, target, name,
                                           self.ent_idx, self.nen_idx)
                outputs = model.forward(enc_inputs, dec_inputs)
                metrics = model.collect_metrics(outputs, target, ptr_index,
                                                kb_index)
                if name == "model_S":
                    metrics_list_S.append(metrics)
                elif name == "model_TB":
                    metrics_list_TB.append(metrics)
                else:
                    metrics_list_TE.append(metrics)
                kd_masks[name] = kd_mask if mask else target_mask
                loss[name] = metrics
                model.update_memory(
                    dialog_state_memory=outputs.dialog_state_memory,
                    kb_state_memory=outputs.kb_state_memory)

        # store necessary data for the three single models
        if self.model_TE is not None:
            kd_mask_e = kd_masks.model_TE
            kd_mask_s = kd_masks.model_S
            kd_mask_b = kd_masks.model_TB
            mask_list_S.append(kd_mask_s)
            length_list.append(tgt_lengths - 1)
            # Both teachers must share the same mask when `mask` is False.
            assert False not in (kd_mask_b == kd_mask_e)
            errD_B = self.discriminator_update(netD=self.discriminator_B,
                                               real_data=loss.model_TB.prob,
                                               fake_data=loss.model_S.prob,
                                               lengths=tgt_lengths - 1,
                                               mask=kd_mask_b)
            errD_E = self.discriminator_update(netD=self.discriminator_E,
                                               real_data=loss.model_TE.prob,
                                               fake_data=loss.model_S.prob,
                                               lengths=tgt_lengths - 1,
                                               mask=kd_mask_e)
            # collect discriminator's total loss
            metrics_DB = Pack(num_samples=metrics.num_samples)
            metrics_DE = Pack(num_samples=metrics.num_samples)
            metrics_DB.add(loss=errD_B, logits=0.0, prob=0.0)
            metrics_DE.add(loss=errD_E, logits=0.0, prob=0.0)
            metrics_list_DB.append(metrics_DB)
            metrics_list_DE.append(metrics_DE)
            # update in a batch
            total_loss_DB = total_loss_DB + errD_B
            total_loss_DE = total_loss_DE + errD_E
        # Per-turn scratch packs are rebuilt next turn.
        loss.clear()
        kd_masks.clear()

    # check loss
    if torch.isnan(total_loss_DB) or torch.isnan(total_loss_DE):
        raise ValueError("NAN loss encountered!")
    # compute and update gradient
    if is_training:
        assert not None in (optimizerDB, optimizerDE)
        optimizerDB.zero_grad()
        optimizerDE.zero_grad()
        total_loss_DB.backward()
        total_loss_DE.backward()
        if grad_clip is not None and grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(
                parameters=self.discriminator_B.parameters(),
                max_norm=grad_clip)
            torch.nn.utils.clip_grad_norm_(
                parameters=self.discriminator_E.parameters(),
                max_norm=grad_clip)
        optimizerDB.step()
        optimizerDE.step()

    """ update generator """
    # begin iterate (a dialogue batch)
    n_turn = len(metrics_list_S)
    assert n_turn == len(turn_inputs) == len(mask_list_S)
    for i in range(n_turn):
        errG, errG_B, errG_E, nll = self.generator_update(
            netG=self.model_S,
            netDB=self.discriminator_B,
            netDE=self.discriminator_E,
            fake_data=metrics_list_S[i].prob,
            length=length_list[i],
            mask=mask_list_S[i],
            nll=metrics_list_S[i].loss,
            lambda_g=self.lambda_g)
        # collect generator's total loss
        metrics_G = Pack(num_samples=metrics_list_S[i].num_samples)
        metrics_G.add(loss=errG, loss_gb=errG_B, loss_ge=errG_E,
                      loss_nll=nll, logits=0.0, prob=0.0)
        metrics_list_G.append(metrics_G)
        # update in a batch
        total_loss_G += errG

    # check loss
    if torch.isnan(total_loss_G):
        raise ValueError("NAN loss encountered!")
    # compute and update gradient
    if is_training:
        assert optimizerG is not None
        optimizerG.zero_grad()
        total_loss_G.backward()
        if grad_clip is not None and grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(
                parameters=self.model_S.parameters(), max_norm=grad_clip)
        optimizerG.step()
    return metrics_list_S, metrics_list_G, metrics_list_DB, metrics_list_DE
def encode(self, enc_inputs, hidden=None):
    """Encode one dialogue turn and build the decoder's initial state.

    Variant that computes a KB selector via ``initialize_kb_v2`` (no KB
    memory tensor). Appends the new turn's encoder outputs to the
    persistent dialog memories (state/history/index/mask) before
    constructing ``dec_init_state``.

    Args:
        enc_inputs: tuple ``(inputs, lengths)`` of token ids and lengths.
        hidden: optional initial encoder hidden state.

    Returns:
        Tuple ``(outputs, dec_init_state)`` where ``outputs`` is an
        (empty) ``Pack``.
    """
    outputs = Pack()
    enc_outputs, enc_hidden = self.encoder(enc_inputs, hidden)
    inputs, lengths = enc_inputs
    batch_size = enc_outputs.size(0)
    max_len = enc_outputs.size(1)
    # True at padding positions (positions attention must ignore)
    attn_mask = sequence_mask(lengths, max_len).eq(0)
    if self.with_bridge:
        enc_hidden = self.bridge(enc_hidden)

    # insert dialog memory
    if self.dialog_state_memory is None:
        # First turn of a dialogue: all memories must be unset together.
        assert self.dialog_history_memory is None
        assert self.history_index is None
        assert self.memory_masks is None
        self.dialog_state_memory = enc_outputs
        self.dialog_history_memory = enc_outputs
        self.history_index = inputs
        self.memory_masks = attn_mask
    else:
        # Later turns: the batch may have shrunk (finished dialogues are
        # dropped), so slice the stored memories to the current batch
        # size before concatenating this turn along the time axis.
        batch_state_memory = self.dialog_state_memory[:batch_size, :, :]
        self.dialog_state_memory = torch.cat(
            [batch_state_memory, enc_outputs], dim=1)
        batch_history_memory = self.dialog_history_memory[: batch_size, :, :]
        self.dialog_history_memory = torch.cat(
            [batch_history_memory, enc_outputs], dim=1)
        batch_history_index = self.history_index[:batch_size, :]
        self.history_index = torch.cat([batch_history_index, inputs], dim=-1)
        batch_memory_masks = self.memory_masks[:batch_size, :]
        self.memory_masks = torch.cat([batch_memory_masks, attn_mask], dim=-1)

    # Slice the KB tensors loaded earlier to the current batch size.
    batch_kb_inputs = self.kbs[:batch_size, :, :]
    batch_kb_state_memory = self.kb_state_memory[:batch_size, :, :]
    batch_kb_slot_memory = self.kb_slot_memory[:batch_size, :, :]
    batch_kb_slot_index = self.kb_slot_index[:batch_size, :]
    kb_mask = self.kb_mask[:batch_size, :]
    selector_mask = self.selector_mask[:batch_size, :]

    selector = self.decoder.initialize_kb_v2(
        enc_hidden=enc_hidden,
        kb_state_memory=batch_kb_state_memory,
        attn_kb_mask=kb_mask)
    # kb_memory, selector = self.decoder.initialize_kb_v3(kb_inputs=batch_kb_inputs, enc_hidden=enc_hidden)
    # v2 produces no KB memory tensor; the decoder state carries None.
    kb_memory = None

    dec_init_state = self.decoder.initialize_state(
        hidden=enc_hidden,
        state_memory=self.dialog_state_memory,
        history_memory=self.dialog_history_memory,
        kb_memory=kb_memory,
        kb_state_memory=batch_kb_state_memory,
        kb_slot_memory=batch_kb_slot_memory,
        history_index=self.history_index,
        kb_slot_index=batch_kb_slot_index,
        attn_mask=self.memory_masks,
        attn_kb_mask=kb_mask,
        selector=selector,
        selector_mask=selector_mask)
    return outputs, dec_init_state
def forward(self, inputs, state):
    """Teacher-forced decode of a whole target sequence.

    Sorts the batch by length, steps the decoder one time-step at a
    time over the valid (non-padding) prefix, then mixes the vocab
    distribution with copy distributions over facts and history using
    the mode gates from ``self.ff``.

    Args:
        inputs: tuple ``(inputs, lengths)`` of target token ids and
            lengths.
        state: decoder state (project type) carrying ``fact``/``hist``
            memories and ``hidden``.

    Returns:
        ``(log_probs, state, output)`` — log of the mixed distribution,
        the (resorted) state, and the last step's decode output.
    """
    inputs, lengths = inputs
    batch_size, max_len = inputs.size()

    # Per-step buffers, filled only for valid (non-padded) positions.
    out_inputs = inputs.new_zeros(
        size=(batch_size, max_len, self.out_input_size), dtype=torch.float)
    fact_len = state.fact.size(1)
    hist_len = state.hist.size(1)
    out_facts = inputs.new_zeros(
        size=(batch_size, max_len, fact_len), dtype=torch.float)
    out_hists = inputs.new_zeros(
        size=(batch_size, max_len, hist_len), dtype=torch.float)
    # prob_hist = inputs.new_zeros(
    #     size=(batch_size, max_len, self.output_size),
    #     dtype=torch.float)
    # prob_fact = inputs.new_zeros(
    #     size=(batch_size, max_len, self.output_size),
    #     dtype=torch.float)

    # sort by lengths (descending) so each step works on a contiguous
    # prefix of still-active sequences
    sorted_lengths, indices = lengths.sort(descending=True)
    inputs = inputs.index_select(0, indices)
    state = state.index_select(indices)

    # number of valid inputs (i.e. not padding index) at each time step
    num_valid_list = sequence_mask(sorted_lengths).int().sum(dim=0)

    for i, num_valid in enumerate(num_valid_list):
        dec_input = inputs[:num_valid, i]
        valid_state = state.slice_select(num_valid)
        out_input, valid_state, output = self.decode(dec_input,
                                                     valid_state,
                                                     is_training=True)
        # Write the advanced hidden state back for the active prefix only.
        state.hidden[:, :num_valid] = valid_state.hidden
        out_inputs[:num_valid, i] = out_input.squeeze(1)
        out_facts[:num_valid, i] = output.attn_f.squeeze(1)
        out_hists[:num_valid, i] = output.attn_h.squeeze(1)

    # Resort back to the caller's original batch order
    _, inv_indices = indices.sort()
    state = state.index_select(inv_indices)
    out_inputs = out_inputs.index_select(0, inv_indices)
    out_facts = out_facts.index_select(0, inv_indices)
    out_hists = out_hists.index_select(0, inv_indices)

    # Mode gates over {vocab, fact-copy, history-copy}
    p_modes = self.ff(out_inputs)
    # (batch_size, max_len, vocab_size)
    prob_vocab = self.output_layer(out_inputs)
    # prob_hist = convert_dist(
    #     out_hists, state.hist, prob_hist)
    # prob_fact = convert_dist(
    #     out_facts, state.fact, prob_fact)
    # a = torch.cat((prob_vocab, prob_hist, prob_fact), -
    #               1).view(batch_size * max_len, self.output_size, -1)
    # b = p_modes.view(batch_size * max_len, -1).unsqueeze(2)
    # prob = torch.bmm(a, b).squeeze().view(batch_size, max_len, -1)

    # Gate each distribution, then scatter the copy weights onto the
    # vocabulary via convert_dist.
    weighted_prob = prob_vocab * p_modes[:, :, 0].unsqueeze(2)
    weighted_f = out_facts * p_modes[:, :, 1].unsqueeze(2)
    weighted_h = out_hists * p_modes[:, :, 2].unsqueeze(2)
    weighted_prob = convert_dist(weighted_h, state.hist, weighted_prob)
    weighted_prob = convert_dist(weighted_f, state.fact, weighted_prob)
    # epsilon keeps log() finite where the mixed probability is zero
    log_probs = torch.log(weighted_prob + 1e-10)
    return log_probs, state, output
def forward(self, dec_inputs, state):
    """Teacher-forced decode returning vocab, attention and KB-copy probs.

    Sorts the batch by length, steps the decoder over the valid prefix
    at each time step while writing the per-step hidden/memory updates
    back in place, then resorts and projects the collected features.

    Args:
        dec_inputs: tuple ``(inputs, lengths)`` of target ids and
            lengths.
        state: decoder state (project type) carrying history/KB memories
            and ``hidden``.

    Returns:
        ``(probs, attn_probs, kb_probs, p_gen, p_con, state)`` — vocab
        distribution, history attention, KB slot attention, generation
        gate, copy gate, and the updated state.
    """
    inputs, lengths = dec_inputs
    batch_size, max_len = inputs.size()

    # Per-step buffers, filled only for valid (non-padded) positions.
    out_inputs = inputs.new_zeros(
        size=(batch_size, max_len, self.out_input_size), dtype=torch.float)
    kb_inputs = inputs.new_zeros(
        size=(batch_size, max_len, self.out_input_size), dtype=torch.float)
    out_attn_size = state.history_memory.size(1)
    out_attn_probs = inputs.new_zeros(
        size=(batch_size, max_len, out_attn_size), dtype=torch.float)
    out_kb_size = state.kb_slot_memory.size(1)
    out_kb_probs = inputs.new_zeros(
        size=(batch_size, max_len, out_kb_size), dtype=torch.float)

    # sort by lengths (descending) so each step works on a contiguous
    # prefix of still-active sequences
    sorted_lengths, indices = lengths.sort(descending=True)
    inputs = inputs.index_select(0, indices)
    state = state.index_select(indices)

    # number of valid inputs (i.e. not padding index) at each time step
    num_valid_list = sequence_mask(sorted_lengths).int().sum(dim=0)

    for i, num_valid in enumerate(num_valid_list):
        dec_input = inputs[:num_valid, i]
        valid_state = state.slice_select(num_valid)
        # decode for one step
        out_input, kb_input, attn, kb_attn, valid_state = self.decode(
            dec_input, valid_state, is_training=True)
        # Write the advanced hidden state and (mutable) memories back for
        # the active prefix only.
        state.hidden[:, :num_valid] = valid_state.hidden
        state.state_memory[:num_valid, :, :] = valid_state.state_memory
        state.kb_state_memory[: num_valid, :, :] = valid_state.kb_state_memory
        out_inputs[:num_valid, i] = out_input.squeeze(1)
        kb_inputs[:num_valid, i] = kb_input.squeeze(1)
        out_attn_probs[:num_valid, i] = attn.squeeze(1)
        out_kb_probs[:num_valid, i] = kb_attn.squeeze(1)

    # Resort back to the caller's original batch order
    _, inv_indices = indices.sort()
    state = state.index_select(inv_indices)
    out_inputs = out_inputs.index_select(0, inv_indices)
    kb_inputs = kb_inputs.index_select(0, inv_indices)
    attn_probs = out_attn_probs.index_select(0, inv_indices)
    kb_probs = out_kb_probs.index_select(0, inv_indices)

    probs = self.output_layer(out_inputs)
    # p_gen: generate-from-vocab gate; p_con: copy-from-KB gate
    p_gen = self.gate_layer(out_inputs)
    p_con = self.copy_gate_layer(kb_inputs)
    return probs, attn_probs, kb_probs, p_gen, p_con, state
def encode(self, inputs, hidden=None, is_training=False):
    """ encode """
    '''
    (translated from the original Chinese note)
    inputs: nested as {field -> (data tensor, length tensor)}:
    {'src':   (data -> shape (batch_size, sen_num, max_len),
               sentence lengths -> shape (batch_size, sen_num)),
     'tgt':   (data -> shape (batch_size, max_len),
               lengths -> shape (batch_size)),
     'cue':   (data -> shape (batch_size, max_len),
               lengths -> shape (batch_size)),
     'label': (data -> shape (batch_size, max_len),
               lengths -> shape (batch_size)),
     'index': (data -> shape (batch_size, max_len),
               lengths -> shape (batch_size))}
    '''
    outputs = Pack()
    ''' Stage 2 '''
    if self.task_id == 1:
        # Strip <bos>/<eos> from src before encoding.
        enc_inputs = inputs.src[0][:, 1:-1], inputs.src[1] - 2
        lengths = inputs.src[1] - 2  # (batch_size)
        enc_outputs, enc_hidden, enc_embedding = self.encoder(
            enc_inputs, hidden)
        # enc_outputs: (batch_size, max_len-2, 2*rnn_hidden_size)
        # enc_hidden: (num_layer, batch_size, 2*rnn_hidden_size)
        if self.with_bridge:
            enc_hidden = self.bridge(enc_hidden)
        # tem_bth,tem_len,tem_hi_size = enc_outputs.size()  # (batch_size, max_len-2, 2*rnn_hidden_size)

        # Positions of keywords in the source; used to pool a "key"
        # representation from the encoder embeddings.
        key_index, len_key_index = inputs.index[0], inputs.index[
            1]  # key_index: (batch_size, idx_max_len)
        max_len = key_index.size(1)
        key_mask = sequence_mask(len_key_index, max_len).eq(
            0)  # key_mask: (batch_size, idx_max_len)
        key_hidden = torch.gather(
            enc_embedding, 1,
            key_index.unsqueeze(-1).repeat(1, 1, enc_embedding.size(
                -1)))  # (batch_size, idx_max_len, 2*rnn_hidden_size)
        # Masked mean over the keyword positions.
        key_global = key_hidden.masked_fill(
            key_mask.unsqueeze(-1),
            0.0).sum(1) / len_key_index.unsqueeze(1).float()
        key_global = self.key_linear(
            key_global)  # (batch_size, 2*rnn_hidden_size)
        # persona_aware = torch.cat([key_global, enc_hidden[-1]], dim=-1)  # (batch_size, 2*rnn_hidden_size)
        persona_aware = key_global + enc_hidden[
            -1]  # (batch_size, 2*rnn_hidden_size)

        # persona
        batch_size, sent_num, sent = inputs.cue[0].size()
        cue_len = inputs.cue[1]  # (batch_size, sent_num)
        # Only shorten real sentences; padded sentences keep length 0.
        cue_len[cue_len > 0] -= 2  # (batch_size, sent_num)
        # Flatten sentences so the persona encoder sees one sentence per row.
        cue_inputs = inputs.cue[0].view(-1, sent)[:, 1:-1], cue_len.view(-1)
        # cue_inputs: ((batch_size*sent_num, max_len-2), (batch_size*sent_num))
        cue_enc_outputs, cue_enc_hidden, _ = self.persona_encoder(
            cue_inputs, hidden)
        # cue_enc_outputs: (batch_size*sent_num, max_len-2, 2*rnn_hidden_size)
        # cue_enc_hidden: (num_layers, batch_size*sent_num, 2*rnn_hidden_size)
        cue_outputs = cue_enc_hidden[-1].view(batch_size, sent_num, -1)
        cue_enc_outputs = cue_enc_outputs.view(
            batch_size, sent_num, cue_enc_outputs.size(1), -1
        )  # (batch_size, sent_num, max_len-2, 2*rnn_hidden_size)
        cue_len = cue_len.view(batch_size, sent_num)
        # cue_outputs: (batch_size, sent_num, 2*rnn_hidden_size)

        # Attention: three hops of persona attention, each hop refining
        # the query with the previously attended persona summary.
        weighted_cue1, cue_attn1 = self.persona_attention(
            query=persona_aware.unsqueeze(1),
            memory=cue_outputs,
            mask=inputs.cue[1].eq(0))
        # weighted_cue: (batch_size, 1, 2*rnn_hidden_size)
        persona_memory1 = weighted_cue1 + persona_aware.unsqueeze(1)
        weighted_cue2, cue_attn2 = self.persona_attention(
            query=persona_memory1,
            memory=cue_outputs,
            mask=inputs.cue[1].eq(0))
        persona_memory2 = weighted_cue2 + persona_aware.unsqueeze(1)
        weighted_cue3, cue_attn3 = self.persona_attention(
            query=persona_memory2,
            memory=cue_outputs,
            mask=inputs.cue[1].eq(0))
        cue_attn = cue_attn3.squeeze(1)  # (batch_size, sent_num)
        outputs.add(cue_attn=cue_attn)
        # Hard selection: index of the most-attended persona sentence.
        indexs = cue_attn.max(dim=1)[1]  # (batch_size)
        if is_training:
            # gumbel_attn = F.gumbel_softmax(torch.log(cue_attn + 1e-10), 0.1, hard=True)
            # persona = torch.bmm(gumbel_attn.unsqueeze(1), cue_outputs)
            # indexs = gumbel_attn.max(-1)[1]
            # cue_lengths = cue_len.gather(1, indexs.unsqueeze(1)).squeeze(1)  # (batch_size)
            # NOTE(review): with the gumbel path commented out, this branch
            # is identical to the else branch below.
            persona = cue_enc_outputs.gather(
                1,
                indexs.view(-1, 1, 1, 1).repeat(
                    1, 1, cue_enc_outputs.size(2),
                    cue_enc_outputs.size(3))).squeeze(
                        1)  # (batch_size, max_len-2, 2*rnn_hidden_size)
            cue_lengths = cue_len.gather(1, indexs.unsqueeze(1)).squeeze(
                1)  # (batch_size)
        else:
            persona = cue_enc_outputs.gather(
                1,
                indexs.view(-1, 1, 1, 1).repeat(
                    1, 1, cue_enc_outputs.size(2),
                    cue_enc_outputs.size(3))).squeeze(
                        1)  # (batch_size, max_len-2, 2*rnn_hidden_size)
            cue_lengths = cue_len.gather(1, indexs.unsqueeze(1)).squeeze(
                1)  # (batch_size)
        outputs.add(indexs=indexs)
        outputs.add(attn_index=inputs.label)  # (batch_size)
        dec_init_state = self.decoder.initialize_state(
            hidden=enc_hidden,
            attn_memory=enc_outputs if self.attn_mode else None,
            memory_lengths=lengths if self.attn_mode else None,  # (batch_size)
            cue_enc_outputs=
            persona,  # (batch_size, max_len-2, 2*rnn_hidden_size)
            cue_lengths=cue_lengths,  # (batch_size)
            task_id=self.task_id)
        # if 'index' in inputs.keys():
        #     outputs.add(attn_index=inputs.index)
    elif self.task_id == 0:
        ''' Stage 1 '''
        # enc_inputs: ((batch_size, max_len-2), (batch_size-2)) — src with
        # <bos>/<eos> stripped; hidden: None
        batch_size, sent_num, sent_len = inputs.src[0].size()
        src_lengths = inputs.src[1]  # (batch_size, sent_num)
        # Only shorten real sentences; padded sentences keep length 0.
        src_lengths[src_lengths > 0] -= 2  # (batch_size, sent_num)
        # Flatten sentences so the encoder sees one sentence per row.
        src_inputs = inputs.src[0].view(
            -1, sent_len)[:, 1:-1], src_lengths.view(-1)
        # src_inputs: ((batch_size*sent_num, max_len-2), (batch_size*sent_num))
        src_enc_outputs, enc_hidden, _ = self.encoder(src_inputs, hidden)
        if self.with_bridge:
            enc_hidden = self.bridge(enc_hidden)
        # src_enc_outputs: (batch_size*sent_num, max_len-2, 2*rnn_hidden_size)
        # enc_hidden: (num_layers, batch_size*sent_num, 2*rnn_hidden_size)
        src_outputs = torch.mean(
            enc_hidden.view(self.num_layers, batch_size, sent_num, -1),
            2)  # mean-pool over sentences
        # src_outputs: (num_layers, batch_size, 2*rnn_hidden_size)

        # persona: ((batch_size, max_len-2), (batch_size)) — persona tensor
        # with <bos>/<eos> stripped
        cue_inputs = inputs.cue[0][:, 1:-1], inputs.cue[1] - 2
        cue_lengths = inputs.cue[1] - 2  # (batch_size)
        cue_enc_outputs, cue_enc_hidden, _ = self.persona_encoder(
            cue_inputs, hidden)
        # cue_enc_outputs: (batch_size, max_len-2, 2*rnn_hidden_size)
        # cue_enc_hidden: (num_layer, batch_size, 2*rnn_hidden_size)
        dec_init_state = self.decoder.initialize_state(
            hidden=src_outputs,
            attn_memory=src_enc_outputs.view(
                batch_size, sent_num, sent_len -
                2, -1) if self.attn_mode else None,
            # (batch_size, sent_num, max_len-2, 2*rnn_hidden_size)
            memory_lengths=src_lengths
            if self.attn_mode else None,  # (batch_size, sent_num)
            cue_enc_outputs=
            cue_enc_outputs,  # (batch_size, max_len-2, 2*rnn_hidden_size)
            cue_lengths=cue_lengths,
            task_id=self.task_id  # (batch_size)
        )
    return outputs, dec_init_state