def collect_metrics(self, outputs):
    """Compute matching metrics from positive/negative logits.

    Returns a Pack with the mean BCE loss, the fraction of positive pairs
    scored > 0.5 and negative pairs scored < 0.5, the mean probability
    margin between them, and the batch size.
    """
    p_logits = outputs.pos_logits
    n_logits = outputs.neg_logits
    bce = F.binary_cross_entropy_with_logits
    # Per-element losses against all-ones / all-zeros targets, averaged jointly.
    pos_term = bce(p_logits, torch.ones_like(p_logits), reduction='none')
    neg_term = bce(n_logits, torch.zeros_like(n_logits), reduction='none')
    total_loss = (pos_term + neg_term).mean()
    p_prob = torch.sigmoid(p_logits)
    n_prob = torch.sigmoid(n_logits)
    metrics = Pack(loss=total_loss,
                   pos_acc=p_prob.gt(0.5).float().mean(),
                   neg_acc=n_prob.lt(0.5).float().mean(),
                   margin=(p_prob - n_prob).mean())
    metrics.add(num_samples=p_logits.size(0))
    return metrics
def collect_rl_metrics(self, sample_outputs, greedy_outputs, target,
                       gold_entity, entity_dir):
    """Collect RL (self-critical) training metrics.

    Uses the greedy decode as the baseline: the advantage is
    sample_reward - greedy_reward, and the policy-gradient loss is the
    advantage-weighted NLL of the sampled sequence.

    Returns a Pack with loss, mean reward, greedy accuracy, and the mean
    BLEU / F1 of the greedy decode.
    """
    num_samples = target.size(0)
    rl_metrics = Pack(num_samples=num_samples)
    loss = 0
    # Log-probabilities / tokens for the sampled and greedy generations.
    logits = sample_outputs.logits
    sample = sample_outputs.pred_word
    greedy_logits = greedy_outputs.logits
    greedy_sample = greedy_outputs.pred_word
    # Rewards for both decodes; BLEU/F1 are only reported for the greedy one.
    sample_reward, _, _ = self.reward_fn(sample, target, gold_entity,
                                         entity_dir)
    greedy_reward, bleu_score, f1_score = self.reward_fn(
        greedy_sample, target, gold_entity, entity_dir)
    reward = sample_reward - greedy_reward  # self-critical advantage
    # RL loss: per-token NLL of the sample, masked at padding, scaled by the
    # advantage, then reduced with torch.<reduction> (e.g. torch.mean).
    sample_log_prob = self.nll_loss(logits, sample,
                                    mask=sample.ne(self.padding_idx),
                                    reduction=False,
                                    matrix=False)  # [batch_size, max_len]
    nll = sample_log_prob * reward.to(sample_log_prob.device)
    nll = getattr(torch, self.nll_loss.reduction)(nll.sum(dim=-1))
    loss += nll
    # Report greedy accuracy alongside the RL loss.
    rl_acc = accuracy(greedy_logits, target, padding_idx=self.padding_idx)
    if reward.dim() == 2:
        # NOTE(review): reward may be per-token; collapse to per-sequence
        # before averaging for reporting.
        reward = reward.sum(dim=-1)
    rl_metrics.add(loss=loss,
                   reward=reward.mean(),
                   rl_acc=rl_acc,
                   bleu_score=bleu_score.mean(),
                   f1_score=f1_score.mean())
    return rl_metrics
def decode(self, input, state, is_training=False):
    """One decoding step fusing an utterance RNN with a knowledge (cue) RNN.

    Both RNNs start from the same previous hidden state; their new hidden
    states are fused (concat+fc, or a learned sigmoid gate) into the next
    decoder hidden state.

    Returns (out_input, state, output) when training (pre-softmax readout
    input) or (log_prob, state, output) otherwise.
    """
    hidden = state.hidden
    rnn_input_list = []
    cue_input_list = []
    out_input_list = []
    output = Pack()
    if self.embedder is not None:
        input = self.embedder(input)
    # shape: (batch_size, 1, input_size)
    input = input.unsqueeze(1)
    rnn_input_list.append(input)
    cue_input_list.append(state.knowledge)
    if self.feature_size is not None:
        feature = state.feature.unsqueeze(1)
        rnn_input_list.append(feature)
        cue_input_list.append(feature)
    if self.attn_mode is not None:
        attn_memory = state.attn_memory
        attn_mask = state.attn_mask
        query = hidden[-1].unsqueeze(1)  # top-layer hidden as attention query
        weighted_context, attn = self.attention(query=query,
                                                memory=attn_memory,
                                                mask=attn_mask)
        rnn_input_list.append(weighted_context)
        cue_input_list.append(weighted_context)
        out_input_list.append(weighted_context)
        output.add(attn=attn)
    rnn_input = torch.cat(rnn_input_list, dim=-1)
    rnn_output, rnn_hidden = self.rnn(rnn_input, hidden)
    cue_input = torch.cat(cue_input_list, dim=-1)
    cue_output, cue_hidden = self.cue_rnn(cue_input, hidden)
    h_y = self.tanh(self.fc1(rnn_hidden))
    h_cue = self.tanh(self.fc2(cue_hidden))
    if self.concat:
        # Fuse by concatenation + projection.
        new_hidden = self.fc3(torch.cat([h_y, h_cue], dim=-1))
    else:
        # Fuse by a learned convex combination (gate k).
        k = self.sigmoid(self.fc3(torch.cat([h_y, h_cue], dim=-1)))
        new_hidden = k * h_y + (1 - k) * h_cue
    # out_input_list.append(new_hidden.transpose(0, 1))  # bug fixed
    # Feed only the top layer of the fused hidden state to the output layer.
    out_input_list.append(new_hidden[-1].unsqueeze(1))
    out_input = torch.cat(out_input_list, dim=-1)
    state.hidden = new_hidden
    if is_training:
        return out_input, state, output
    else:
        log_prob = self.output_layer(out_input)
        return log_prob, state, output
def generate(self, batch_iter, num_batches):
    """Run inference over batch_iter and collect per-sample predictions.

    For each target token the model's argmax class is mapped through
    itoemo; gold emotion labels are decoded the same way for comparison.
    Returns a flat list of result records (one per sample).

    NOTE(review): num_batches is unused here — verify whether callers rely
    on it elsewhere.
    """
    self.model.eval()
    itoemo = ['NORM', 'POS', 'NEG']
    with torch.no_grad():
        results = []
        for inputs in batch_iter:
            enc_inputs = inputs  # the whole batch Pack is the encoder input
            dec_inputs = inputs.num_tgt_input
            enc_outputs = Pack()
            outputs = self.model.forward(enc_inputs, dec_inputs, hidden=None)
            outputs = outputs.logits
            preds = outputs.max(dim=2)  # (values, indices) over classes
            # news_id = inputs.id
            tgt_raw = inputs.raw_tgt
            preds = preds[1].tolist()  # keep only the argmax indices
            temp_a_1 = []
            emo_b_1 = []
            temp = []
            tgt_emo = inputs.tgt_emo[0].tolist()
            for a, b, c in zip(tgt_raw, preds, tgt_emo):
                # enc_outputs.add(preds=preds, scores=scores, emos=emos, target_emos=temp)
                # result_batch = enc_outputs.flatten()
                # results += result_batch
                a = a[1:]  # drop the leading (BOS) token
                temp_a = []
                emo_b = []
                emo_c = []
                for i, entity in enumerate(a):
                    temp_a.append(entity)
                    emo_b.append(itoemo[b[i]])  # predicted emotion label
                    emo_c.append(itoemo[c[i]])  # gold emotion label
                # tgt_raw=tgt_raw[1:]
                assert len(temp_a) == len(emo_b)
                assert len(emo_c) == len(emo_b)
                temp_a_1.append([temp_a])  # tokens
                emo_b_1.append([emo_b])    # predicted emotions
                temp.append(emo_c)         # gold emotions
            # temp = []
            # tgt_emo = inputs.tgt_emo[0].tolist()
            # for item in tgt_emo:
            #     temp.append([itoemo[x] for x in item])
            # print(emo_b_1)
            # print(temp)
            if hasattr(inputs, 'id') and inputs.id is not None:
                # NOTE(review): mixes attribute access (inputs.id) with item
                # access (inputs['id']) — presumably Pack supports both; confirm.
                enc_outputs.add(id=inputs['id'])
            enc_outputs.add(tgt=tgt_raw, preds=temp_a_1, emos=emo_b_1,
                            target_emos=temp)
            result_batch = enc_outputs.flatten()
            results += result_batch
        return results
def encode(self, inputs, hidden=None, is_training=False):
    """Encode the source and build decoder init state with NTM topic features.

    inputs: src, topic_src, topic_tgt, [tgt]

    When use_ntm is set, topic-word embeddings are projected into the word
    embedding space and selected per-sample via one-hot topic labels to form
    an extra topic_feature for the decoder.
    """
    outputs = Pack()
    # Strip <bos>/<eos>; also binds the (tokens, lengths) pair to enc_inputs.
    enc_inputs = _, lengths = inputs.src[0][:, 1:-1], inputs.src[1] - 2
    enc_outputs, enc_hidden = self.encoder(enc_inputs, hidden)
    if self.with_bridge:
        enc_hidden = self.bridge(enc_hidden)
    guide_score = enc_hidden[-1]
    decoder_init_state = enc_hidden
    if self.use_ntm:
        bow_src = inputs.bow
        ntm_stat = self.ntm(bow_src)
        outputs.add(ntm_loss=ntm_stat['loss'])
        embedding_weight = self.topic_embedder.weight
        topic_word_logit = self.ntm.topics.get_topic_word_logit()
        EPS = 1e-12
        w = topic_word_logit.transpose(0, 1)  # V x K
        nv = embedding_weight / (
            torch.norm(embedding_weight, dim=1, keepdim=True) + EPS)  # V x dim
        # NOTE(review): nw is computed but never used below — the projection
        # uses the unnormalized w. Possibly intended to use nw; confirm.
        nw = w / (torch.norm(w, dim=0, keepdim=True) + EPS)  # V x K
        t = nv.transpose(0, 1) @ w  # dim x K
        t = t.transpose(0, 1)       # K x dim: one embedding-space vector per topic
        topic_feature_input = []
        if 'S' in self.decoder_attention_channels:
            src_labels = F.one_hot(inputs.topic_src_label,
                                   num_classes=self.topic_num)  # B * K
            src_topics = src_labels.float() @ t
            topic_feature_input.append(src_topics)
        if 'T' in self.decoder_attention_channels:
            tgt_labels = F.one_hot(inputs.topic_tgt_label,
                                   num_classes=self.topic_num)
            tgt_topics = tgt_labels.float() @ t
            topic_feature_input.append(tgt_topics)
        topic_feature = self.t_to_feature(
            torch.cat(topic_feature_input, dim=-1))
    # NOTE(review): topic_feature is only defined when use_ntm is true — this
    # raises NameError otherwise. Presumably use_ntm is always on; confirm.
    dec_init_state = self.decoder.initialize_state(
        hidden=decoder_init_state,
        attn_memory=enc_outputs,
        memory_lengths=lengths,
        guide_score=guide_score,
        topic_feature=topic_feature)
    return outputs, dec_init_state
def decode(self, input, state, is_training=False):
    """One Mem2Seq-style decoding step over the multi-hop key-value memory.

    The GRU state queries the memory for max_hops hops; the vocabulary
    distribution is produced from the first hop, and the last hop's
    pre-softmax scores form the pointer distribution.

    Returns (p_vocab, state, output) when training (pre-softmax values) or
    (prob_v, state, output) (log-probabilities) otherwise.
    """
    last_hidden = state.hidden
    rnn_input_list = []
    output = Pack()
    embed_q = self.C[0](input).unsqueeze(1)  # b * e --> b * 1 * e
    batch_size = input.size(0)
    rnn_input_list.append(embed_q)
    if self.attn_mode is not None:
        attn_memory = state.attn_memory
        attn_mask = state.attn_mask
        query = last_hidden[-1].unsqueeze(1)
        weighted_context, attn = self.attention(query=query,
                                                memory=attn_memory,
                                                mask=attn_mask)
        rnn_input_list.append(weighted_context)
        output.add(attn=attn)
    rnn_input = torch.cat(rnn_input_list, dim=-1)
    # Bug fix: the GRU output used to be assigned to `output`, clobbering the
    # metrics Pack built above (losing its `attn` entry) and returning a raw
    # tensor as the third value. Bind it to its own name instead.
    rnn_output, new_hidden = self.gru(rnn_input, last_hidden)
    state.hidden = new_hidden
    u = [new_hidden[0].squeeze()]
    for hop in range(self.max_hops):
        m_A = self.m_story[hop]
        m_A = m_A[:batch_size]
        if len(list(u[-1].size())) == 1:
            u[-1] = u[-1].unsqueeze(0)  # used for bsz = 1.
        u_temp = u[-1].unsqueeze(1).expand_as(m_A)
        prob_lg = torch.sum(m_A * u_temp, 2)  # attention scores over memory
        prob_p = self.log_softmax(prob_lg)
        m_C = self.m_story[hop + 1]
        m_C = m_C[:batch_size]
        prob = prob_p.unsqueeze(2).expand_as(m_C)
        o_k = torch.sum(m_C * prob, 1)  # read from the value memory
        if hop == 0:
            # Vocabulary distribution comes from the first hop only.
            p_vocab = self.W1(torch.cat((u[0], o_k), 1))
            prob_v = self.log_softmax(p_vocab)
        u_k = u[-1] + o_k
        u.append(u_k)
    p_ptr = prob_lg  # last hop's scores act as the pointer distribution
    if is_training:
        # p_ptr / p_vocab are pre-softmax values, not probabilities.
        return p_vocab, state, output
    else:
        return prob_v, state, output
def collect_metrics(self, outputs, target):
    """Collect NLL loss and token accuracy for one generation batch."""
    metrics = Pack(num_samples=target.size(0))
    logits = outputs.logits
    nll = self.nll_loss(logits, target)
    # Count non-padding tokens so the NLL can be normalized per word.
    n_words = target.ne(self.padding_idx).sum().item()
    token_acc = accuracy(logits, target, padding_idx=self.padding_idx)
    metrics.add(nll=(nll, n_words), acc=token_acc)
    metrics.add(loss=0 + nll)
    return metrics
def forward(self, src_inputs, pos_tgt_inputs, neg_tgt_inputs,
            src_hidden=None, tgt_hidden=None):
    """Score a source against positive/negative targets.

    Each side is encoded and its final top-layer hidden state taken; the
    logit is the dot product between source and target representations.
    Returns a Pack with pos_logits and neg_logits.
    """
    # [1][-1]: take the hidden-state part of the encoder result, top layer.
    src_repr = self.src_encoder(src_inputs, src_hidden)[1][-1]
    if self.with_project:
        src_repr = self.project(src_repr)
    pos_repr = self.tgt_encoder(pos_tgt_inputs, tgt_hidden)[1][-1]
    neg_repr = self.tgt_encoder(neg_tgt_inputs, tgt_hidden)[1][-1]
    result = Pack()
    result.add(pos_logits=(src_repr * pos_repr).sum(dim=-1),
               neg_logits=(src_repr * neg_repr).sum(dim=-1))
    return result
def encode(self, inputs, hidden=None, is_training=False):
    """Encode the source and build decoder init state with topic-word memories.

    inputs: src, topic_src, topic_tgt, [tgt]

    Top-k topic-word indices from the NTM are selected per sample via the
    one-hot topic labels, embedded, and handed to the decoder as src/tgt
    topic memories.
    """
    outputs = Pack()
    # Strip <bos>/<eos>; also binds the (tokens, lengths) pair to enc_inputs.
    enc_inputs = _, lengths = inputs.src[0][:, 1:-1], inputs.src[1] - 2
    enc_outputs, enc_hidden = self.encoder(enc_inputs, hidden)
    if self.with_bridge:
        enc_hidden = self.bridge(enc_hidden)
    guide_score = enc_hidden[-1]
    decoder_init_state = enc_hidden
    bow_src = inputs.bow
    ntm_stat = self.ntm(bow_src)
    outputs.add(ntm_loss=ntm_stat['loss'])
    # obtain topic words: top-k word indices per topic.
    _, tw_indices = self.ntm.get_topics().topk(self.topic_k, dim=1)  # K * k
    src_labels = F.one_hot(inputs.topic_src_label,
                           num_classes=self.topic_num)  # B * K
    tgt_labels = F.one_hot(inputs.topic_tgt_label,
                           num_classes=self.topic_num)
    # One-hot @ indices selects each sample's topic's word indices.
    src_words = src_labels.float() @ tw_indices.float()  # B * k
    src_words = src_words.detach().long()
    tgt_words = tgt_labels.float() @ tw_indices.float()
    tgt_words = tgt_words.detach().long()
    # only src topic word
    src_outputs = self.topic_layer(src_words)  # b * k * h
    # only tgt topic word
    tgt_outputs = self.topic_layer(tgt_words)  # b * k * h
    dec_init_state = self.decoder.initialize_state(
        hidden=decoder_init_state,
        attn_memory=enc_outputs,
        memory_lengths=lengths,
        guide_score=guide_score,
        src_memory=src_outputs,
        tgt_memory=tgt_outputs,
    )
    return outputs, dec_init_state
def decode(self, input, state, is_training=False):
    """One input-feeding decoding step (Luong-style).

    The previous step's attentional output (state.input_feed) is concatenated
    with the current embedding as the RNN input; the new attentional output
    becomes the next step's input feed.
    """
    hidden = state.hidden
    input_feed = state.input_feed
    output = Pack()
    if self.embedder is not None:
        input = self.embedder(input)
    # shape: (batch_size, 1, input_size)
    input = input.unsqueeze(1)
    input_feed = input_feed  # no-op; kept as-is (NOTE(review): can be removed)
    rnn_input = torch.cat([input, input_feed], dim=-1)
    rnn_output, new_hidden = self.rnn(rnn_input, hidden)
    attn_memory = state.attn_memory
    # new_hidden appears to be an (h, c) pair — query with h's top layer.
    query = new_hidden[0][-1].unsqueeze(1)
    weighted_context, attn = self.attention(query=query,
                                            memory=attn_memory,
                                            mask=state.mask.eq(0))
    final_output = torch.cat([query, weighted_context], dim=-1)
    # fusion_sigmod = self.fusion(final_output)
    #
    # fusion_hidden = fusion_sigmod*weighted_context+(1-fusion_sigmod)*query
    final_output = self.fc1(final_output)
    output.add(attn=attn)
    state.hidden = list(new_hidden)  # tuple -> list so it stays mutable
    state.input_feed = final_output
    # state.input_feed=fusion_hidden
    out_input = final_output
    if is_training:
        return out_input, state, output
    else:
        log_prob = self.output_layer(out_input)
        return log_prob, state, output
def collect_metrics(self, outputs, target, ptr_index, kb_index):
    """Collect generation metrics for one batch.

    The gate and selector losses are currently disabled; a zero `ptr` loss
    is kept as a placeholder so the reported metric keys stay stable.
    ptr_index and kb_index are accepted for interface compatibility but
    unused while those losses are off.
    """
    metrics = Pack(num_samples=target.size(0))
    logits = outputs.logits
    # Generation (NLL) loss.
    nll = self.nll_loss(logits, target)
    # Disabled pointer/selector loss — zero placeholder, moved to the right
    # device when running on GPU.
    loss_ptr = torch.tensor(0.0)
    if self.use_gpu:
        loss_ptr = loss_ptr.cuda()
    total = nll + loss_ptr
    metrics.add(loss=total,
                ptr=loss_ptr,
                acc=accuracy(logits, target, padding_idx=self.padding_idx),
                logits=logits,
                prob=outputs.prob)
    return metrics
def interact(self, src, cue=None):
    """Answer a single raw utterance (optionally with a cue).

    Numericalizes the input, runs the model for one candidate, and returns
    the denumericalized top prediction. Returns None for an empty source.
    """
    if src == "":
        return None
    batch = Pack()
    batch.add(src=list2tensor(self.src_field.numericalize([src])))
    if cue is not None:
        batch.add(cue=list2tensor(self.cue_field.numericalize([cue])))
    if self.use_gpu:
        batch = batch.cuda()
    _, preds, _, _ = self.forward(inputs=batch, num_candidates=1)
    # preds[0][0]: best candidate of the single input sample.
    return self.tgt_field.denumericalize(preds[0][0])
def decode(self, input, state, is_training=False):
    """One attention-decoder step.

    Builds the RNN input from the embedded token, optional feature, and
    attention context; the readout concatenates the context with the RNN
    output. Returns the raw readout when training, else log-probabilities.
    """
    prev_hidden = state.hidden
    step_parts = []     # concatenated into the RNN input
    readout_parts = []  # concatenated into the output-layer input
    result = Pack()
    if self.embedder is not None:
        input = self.embedder(input)
    # (batch_size, 1, input_size)
    input = input.unsqueeze(1)
    step_parts.append(input)
    if self.feature_size is not None:
        step_parts.append(state.feature.unsqueeze(1))
    if self.attn_mode is not None:
        context, attn_weights = self.attention(
            query=prev_hidden[-1].unsqueeze(1),
            memory=state.attn_memory,
            mask=state.attn_mask)
        step_parts.append(context)
        readout_parts.append(context)
        result.add(attn=attn_weights)
    rnn_out, next_hidden = self.rnn(torch.cat(step_parts, dim=-1),
                                    prev_hidden)
    readout_parts.append(rnn_out)
    readout = torch.cat(readout_parts, dim=-1)
    state.hidden = next_hidden
    if is_training:
        return readout, state, result
    return self.output_layer(readout), state, result
def collect_metrics(self, outputs, target, bridge=None, epoch=-1):
    """Collect response-generation metrics plus the (scaled) NTM loss."""
    num_samples = target.size(0)
    metrics = Pack(num_samples=num_samples)
    loss = 0
    # response generation
    logits = outputs.logits
    nll = self.nll_loss(logits, target)
    num_words = target.ne(self.padding_idx).sum().item()
    acc = accuracy(logits, target, padding_idx=self.padding_idx)
    metrics.add(nll=(nll, num_words), acc=acc)
    loss += nll
    # neural topic model
    # NOTE(review): .item() turns the NTM loss into a plain float, so no
    # gradient flows into the NTM from this loss term — confirm the NTM is
    # trained separately (e.g. via the ntm_loss entry in outputs).
    ntm_loss = outputs.ntm_loss.sum().item()
    loss += ntm_loss / self.topic_vocab_size * 0.3
    metrics.add(loss=loss)
    return metrics
def collect_metrics(self, outputs, target, epoch=-1):
    """Collect generation metrics plus an optional persona-selection loss.

    Returns (metrics, scores) where scores are per-sample negative NLLs
    (higher is better), unlike sibling collect_metrics methods that return
    metrics only.
    """
    num_samples = target.size(0)
    metrics = Pack(num_samples=num_samples)
    loss = 0
    # test begin
    # nll = self.nll(torch.log(outputs.posterior_attn+1e-10), outputs.attn_index)
    # loss += nll
    # attn_acc = attn_accuracy(outputs.posterior_attn, outputs.attn_index)
    # metrics.add(attn_acc=attn_acc)
    # metrics.add(loss=loss)
    # return metrics
    # test end
    logits = outputs.logits
    # Unreduced NLL negated: per-sample log-likelihood scores.
    scores = -self.nll_loss(logits, target, reduction=False)
    nll_loss = self.nll_loss(logits, target)
    num_words = target.ne(self.padding_idx).sum().item()
    acc = accuracy(logits, target, padding_idx=self.padding_idx)
    metrics.add(nll=(nll_loss, num_words), acc=acc)
    # persona loss: only when gold persona indices are available.
    if 'attn_index' in outputs:
        attn_acc = attn_accuracy(outputs.cue_attn, outputs.attn_index)
        metrics.add(attn_acc=attn_acc)
        per_logits = torch.log(outputs.cue_attn + self.eps)  # cue_attn(batch_size, sent_num)
        per_labels = outputs.attn_index  # (batch_size)
        use_per_loss = self.persona_loss(per_logits, per_labels)
        metrics.add(use_per_loss=use_per_loss)
        # Weighted mix of persona-selection and generation losses.
        loss += 0.7 * use_per_loss
        loss += 0.3 * nll_loss
    else:
        loss += nll_loss
    metrics.add(loss=loss)
    return metrics, scores
def collect_metrics(self, outputs, target, emo_target):
    """Collect copy-generation loss, emotion-tagging loss, and copy accuracy.

    target / emo_target are (tokens, lengths) pairs. Both losses are
    masked negative log-likelihoods normalized by the total word count.
    """
    num_samples = target[0].size(0)
    num_words = target[1].sum().item()
    metrics = Pack(num_samples=num_samples)
    target_len = target[1]
    mask = sequence_mask(target_len)
    mask = mask.float()
    # logits = outputs.logits
    # nll = self.nll_loss(logits, target)
    out_copy = outputs.out_copy
    # out_copy: batch x max_len x src — per-step probabilities over the source.
    target_loss = out_copy.gather(2, target[0].unsqueeze(-1)).squeeze(-1)
    target_loss = target_loss * mask
    # NOTE(review): adding 1e-15 after masking means padded positions
    # contribute log(1e-15) ≈ -34.5 each to the sum — a constant (zero-grad)
    # offset that inflates the reported loss; confirm this is intended.
    target_loss += 1e-15
    target_loss = target_loss.log()
    loss = -((target_loss.sum()) / num_words)
    out_emo = outputs.logits  # batch x max_len x dim
    batch_size, max_len, class_num = out_emo.size()
    # out_emo=out_emo.view(batch_size*max_len, class_num)
    # emo_target=emo_target.view(-1)
    target_emo_loss = out_emo.gather(
        2, emo_target[0].unsqueeze(-1)).squeeze(-1)
    # NOTE(review): in-place decrement mutates target[1] as seen by the
    # caller — confirm callers do not reuse the lengths afterwards.
    target_len -= 1
    mask_ = sequence_mask(target_len)
    mask_ = mask_.float()
    # Emotion labels are one step shorter; embed the shorter mask into a
    # max_len-wide zero mask.
    new_mask = mask.data.new(batch_size, max_len).zero_()
    # print(mask.size())
    # print(new_mask.size())
    new_mask[:, :max_len - 1] = mask_
    target_emo_loss = target_emo_loss * new_mask
    target_emo_loss += 1e-15
    target_emo_loss = target_emo_loss.log()
    emo_loss = -((target_emo_loss.sum()) / num_words)
    metrics.add(loss=loss)
    metrics.add(emo_loss=emo_loss)
    # Here we only compute accuracy on the copy distribution.
    acc = accuracy(out_copy, target[0], mask=mask)
    metrics.add(acc=acc)
    return metrics
def encode(self, inputs, hidden=None, is_training=False):
    """Encode the source and select knowledge via prior/posterior attention.

    At train time (use_posterior) the knowledge is chosen with a posterior
    attention over the gold response, optionally discretized with Gumbel
    softmax; otherwise the prior attention over the source is used.
    Optional auxiliary signals: BoW logits, DSSM matching logits with an
    in-batch negative (shift-by-one), and a scalar weight gate on the
    selected knowledge.
    """
    outputs = Pack()
    # Strip <bos>/<eos>; also binds the (tokens, lengths) pair to enc_inputs.
    enc_inputs = _, lengths = inputs.src[0][:, 1:-1], inputs.src[1] - 2
    enc_outputs, enc_hidden = self.encoder(enc_inputs, hidden)
    if self.with_bridge:
        enc_hidden = self.bridge(enc_hidden)
    # knowledge
    batch_size, sent_num, sent = inputs.cue[0].size()
    tmp_len = inputs.cue[1]
    # NOTE(review): in-place decrement mutates inputs.cue[1] for the caller.
    tmp_len[tmp_len > 0] -= 2
    cue_inputs = inputs.cue[0].view(-1, sent)[:, 1:-1], tmp_len.view(-1)
    cue_enc_outputs, cue_enc_hidden = self.knowledge_encoder(
        cue_inputs, hidden)
    # One sentence-level vector per knowledge sentence.
    cue_outputs = cue_enc_hidden[-1].view(batch_size, sent_num, -1)
    # Attention: prior P(z|u) from the source hidden state.
    weighted_cue, cue_attn = self.prior_attention(
        query=enc_hidden[-1].unsqueeze(1),
        memory=cue_outputs,
        mask=inputs.cue[1].eq(0))
    cue_attn = cue_attn.squeeze(1)
    outputs.add(prior_attn=cue_attn)
    indexs = cue_attn.max(dim=1)[1]
    # hard attention: pick the argmax sentence; soft otherwise.
    if self.use_gs:
        knowledge = cue_outputs.gather(1, \
            indexs.view(-1, 1, 1).repeat(1, 1, cue_outputs.size(-1)))
    else:
        knowledge = weighted_cue
    if self.use_posterior:
        tgt_enc_inputs = inputs.tgt[0][:, 1:-1], inputs.tgt[1] - 2
        _, tgt_enc_hidden = self.knowledge_encoder(tgt_enc_inputs, hidden)
        posterior_weighted_cue, posterior_attn = self.posterior_attention(
            # P(z|u,r)
            # query=torch.cat([dec_init_hidden[-1], tgt_enc_hidden[-1]], dim=-1).unsqueeze(1)
            # P(z|r)
            query=tgt_enc_hidden[-1].unsqueeze(1),
            memory=cue_outputs,
            mask=inputs.cue[1].eq(0))
        posterior_attn = posterior_attn.squeeze(1)
        outputs.add(posterior_attn=posterior_attn)
        # Gumbel Softmax: differentiable hard selection from the posterior.
        if self.use_gs:
            gumbel_attn = F.gumbel_softmax(torch.log(posterior_attn + 1e-10),
                                           0.1, hard=True)
            outputs.add(gumbel_attn=gumbel_attn)
            knowledge = torch.bmm(gumbel_attn.unsqueeze(1), cue_outputs)
            indexs = gumbel_attn.max(-1)[1]
        else:
            knowledge = posterior_weighted_cue
            indexs = posterior_attn.max(dim=1)[1]
        if self.use_bow:
            bow_logits = self.bow_output_layer(knowledge)
            outputs.add(bow_logits=bow_logits)
        if self.use_dssm:
            dssm_knowledge = self.dssm_project(knowledge)
            outputs.add(dssm=dssm_knowledge)
            outputs.add(reply_vec=tgt_enc_hidden[-1])
            # neg sample: shift the batch by one to get in-batch negatives.
            neg_idx = torch.arange(enc_inputs[1].size(0)).type_as(
                enc_inputs[1])
            neg_idx = (neg_idx + 1) % neg_idx.size(0)
            neg_tgt_enc_inputs = tgt_enc_inputs[0][
                neg_idx], tgt_enc_inputs[1][neg_idx]
            _, neg_tgt_enc_hidden = self.knowledge_encoder(
                neg_tgt_enc_inputs, hidden)
            pos_logits = (enc_hidden[-1] * tgt_enc_hidden[-1]).sum(dim=-1)
            neg_logits = (enc_hidden[-1] *
                          neg_tgt_enc_hidden[-1]).sum(dim=-1)
            outputs.add(pos_logits=pos_logits, neg_logits=neg_logits)
    elif is_training:
        # Training without posterior: optionally Gumbel-sample from the prior.
        if self.use_gs:
            gumbel_attn = F.gumbel_softmax(torch.log(cue_attn + 1e-10),
                                           0.1, hard=True)
            knowledge = torch.bmm(gumbel_attn.unsqueeze(1), cue_outputs)
            indexs = gumbel_attn.max(-1)[1]
        else:
            knowledge = weighted_cue
    outputs.add(indexs=indexs)
    if 'index' in inputs.keys():
        outputs.add(attn_index=inputs.index)
    if self.use_kd:
        knowledge = self.knowledge_dropout(knowledge)
    if self.weight_control:
        # Scale the knowledge by a learned scalar gate from its match with
        # the source representation.
        weights = (enc_hidden[-1] * knowledge.squeeze(1)).sum(dim=-1)
        weights = self.sigmoid(weights)
        # norm in batch
        # weights = weights / weights.mean().item()
        outputs.add(weights=weights)
        knowledge = knowledge * weights.view(-1, 1, 1).repeat(
            1, 1, knowledge.size(-1))
    dec_init_state = self.decoder.initialize_state(
        hidden=enc_hidden,
        attn_memory=enc_outputs if self.attn_mode else None,
        memory_lengths=lengths if self.attn_mode else None,
        knowledge=knowledge)
    return outputs, dec_init_state
def encode(self, inputs, hidden=None, is_training=False):
    """Two-stage encoder for persona-aware generation.

    inputs is a Pack of (data, lengths) pairs:
      src:   (batch, sent_num, max_len) in stage 0; (batch, max_len) in stage 1
      tgt:   (batch, max_len)
      cue:   persona sentences
      label: gold persona index
      index: key-word positions in src (stage 1)

    task_id == 1 (stage 2): encode src, pool key-word embeddings into a
    persona-aware query, run three hops of attention over the persona
    sentences, and hand the selected persona to the decoder.
    task_id == 0 (stage 1): encode multi-sentence src, mean-pool over
    sentences, and pass the single persona through unchanged.
    """
    outputs = Pack()
    # stage 2
    if self.task_id == 1:
        enc_inputs = inputs.src[0][:, 1:-1], inputs.src[1] - 2
        lengths = inputs.src[1] - 2  # (batch_size)
        enc_outputs, enc_hidden, enc_embedding = self.encoder(
            enc_inputs, hidden)
        # enc_outputs: (batch_size, max_len-2, 2*rnn_hidden_size)
        # enc_hidden:  (num_layer, batch_size, 2*rnn_hidden_size)
        if self.with_bridge:
            enc_hidden = self.bridge(enc_hidden)
        key_index, len_key_index = inputs.index[0], inputs.index[
            1]  # key_index: (batch_size, idx_max_len)
        max_len = key_index.size(1)
        key_mask = sequence_mask(len_key_index, max_len).eq(
            0)  # key_mask: (batch_size, idx_max_len)
        # Gather the embeddings of the key words and mean-pool them.
        key_hidden = torch.gather(
            enc_embedding, 1,
            key_index.unsqueeze(-1).repeat(1, 1, enc_embedding.size(
                -1)))  # (batch_size, idx_max_len, 2*rnn_hidden_size)
        key_global = key_hidden.masked_fill(
            key_mask.unsqueeze(-1),
            0.0).sum(1) / len_key_index.unsqueeze(1).float()
        key_global = self.key_linear(
            key_global)  # (batch_size, 2*rnn_hidden_size)
        # persona_aware = torch.cat([key_global, enc_hidden[-1]], dim=-1)
        persona_aware = key_global + enc_hidden[
            -1]  # (batch_size, 2*rnn_hidden_size)
        # persona
        batch_size, sent_num, sent = inputs.cue[0].size()
        cue_len = inputs.cue[1]  # (batch_size, sent_num)
        # NOTE(review): in-place decrement mutates inputs.cue[1] for the caller.
        cue_len[cue_len > 0] -= 2
        cue_inputs = inputs.cue[0].view(-1,
                                        sent)[:, 1:-1], cue_len.view(-1)
        # cue_inputs: ((batch*sent_num, max_len-2), (batch*sent_num,))
        cue_enc_outputs, cue_enc_hidden, _ = self.persona_encoder(
            cue_inputs, hidden)
        cue_outputs = cue_enc_hidden[-1].view(batch_size, sent_num, -1)
        cue_enc_outputs = cue_enc_outputs.view(
            batch_size, sent_num, cue_enc_outputs.size(1), -1
        )  # (batch_size, sent_num, max_len-2, 2*rnn_hidden_size)
        cue_len = cue_len.view(batch_size, sent_num)
        # Three hops of attention over persona sentences, re-anchored on the
        # persona-aware query each hop.
        weighted_cue1, cue_attn1 = self.persona_attention(
            query=persona_aware.unsqueeze(1),
            memory=cue_outputs,
            mask=inputs.cue[1].eq(0))
        persona_memory1 = weighted_cue1 + persona_aware.unsqueeze(1)
        weighted_cue2, cue_attn2 = self.persona_attention(
            query=persona_memory1,
            memory=cue_outputs,
            mask=inputs.cue[1].eq(0))
        persona_memory2 = weighted_cue2 + persona_aware.unsqueeze(1)
        weighted_cue3, cue_attn3 = self.persona_attention(
            query=persona_memory2,
            memory=cue_outputs,
            mask=inputs.cue[1].eq(0))
        cue_attn = cue_attn3.squeeze(1)  # (batch_size, sent_num)
        outputs.add(cue_attn=cue_attn)
        indexs = cue_attn.max(dim=1)[1]  # (batch_size)
        # Both branches currently do the same hard selection; the Gumbel
        # variant is commented out.
        if is_training:
            # gumbel_attn = F.gumbel_softmax(torch.log(cue_attn + 1e-10), 0.1, hard=True)
            # persona = torch.bmm(gumbel_attn.unsqueeze(1), cue_outputs)
            # indexs = gumbel_attn.max(-1)[1]
            # cue_lengths = cue_len.gather(1, indexs.unsqueeze(1)).squeeze(1)  # (batch_size)
            persona = cue_enc_outputs.gather(
                1,
                indexs.view(-1, 1, 1, 1).repeat(
                    1, 1, cue_enc_outputs.size(2),
                    cue_enc_outputs.size(3))).squeeze(
                        1)  # (batch_size, max_len-2, 2*rnn_hidden_size)
            cue_lengths = cue_len.gather(1, indexs.unsqueeze(1)).squeeze(
                1)  # (batch_size)
        else:
            persona = cue_enc_outputs.gather(
                1,
                indexs.view(-1, 1, 1, 1).repeat(
                    1, 1, cue_enc_outputs.size(2),
                    cue_enc_outputs.size(3))).squeeze(
                        1)  # (batch_size, max_len-2, 2*rnn_hidden_size)
            cue_lengths = cue_len.gather(1, indexs.unsqueeze(1)).squeeze(
                1)  # (batch_size)
        outputs.add(indexs=indexs)
        outputs.add(attn_index=inputs.label)  # (batch_size)
        dec_init_state = self.decoder.initialize_state(
            hidden=enc_hidden,
            attn_memory=enc_outputs if self.attn_mode else None,
            memory_lengths=lengths if self.attn_mode else None,  # (batch_size)
            cue_enc_outputs=persona,  # (batch_size, max_len-2, 2*rnn_hidden_size)
            cue_lengths=cue_lengths,  # (batch_size)
            task_id=self.task_id)
        # if 'index' in inputs.keys():
        #     outputs.add(attn_index=inputs.index)
    elif self.task_id == 0:
        # stage 1
        # src is multi-sentence here: (batch, sent_num, sent_len).
        batch_size, sent_num, sent_len = inputs.src[0].size()
        src_lengths = inputs.src[1]  # (batch_size, sent_num)
        # NOTE(review): in-place decrement mutates inputs.src[1] for the caller.
        src_lengths[src_lengths > 0] -= 2
        src_inputs = inputs.src[0].view(
            -1, sent_len)[:, 1:-1], src_lengths.view(-1)
        # src_inputs: ((batch*sent_num, max_len-2), (batch*sent_num,))
        src_enc_outputs, enc_hidden, _ = self.encoder(src_inputs, hidden)
        if self.with_bridge:
            enc_hidden = self.bridge(enc_hidden)
        # Mean-pool the per-sentence hidden states over sentences.
        src_outputs = torch.mean(
            enc_hidden.view(self.num_layers, batch_size, sent_num, -1), 2)
        # src_outputs: (num_layers, batch_size, 2*rnn_hidden_size)
        # persona: single sentence, <bos>/<eos> stripped.
        cue_inputs = inputs.cue[0][:, 1:-1], inputs.cue[1] - 2
        cue_lengths = inputs.cue[1] - 2  # (batch_size)
        cue_enc_outputs, cue_enc_hidden, _ = self.persona_encoder(
            cue_inputs, hidden)
        # cue_enc_outputs: (batch_size, max_len-2, 2*rnn_hidden_size)
        dec_init_state = self.decoder.initialize_state(
            hidden=src_outputs,
            attn_memory=src_enc_outputs.view(
                batch_size, sent_num, sent_len -
                2, -1) if self.attn_mode else None,
            memory_lengths=src_lengths if self.attn_mode else None,
            cue_enc_outputs=cue_enc_outputs,
            cue_lengths=cue_lengths,
            task_id=self.task_id)
    return outputs, dec_init_state
def iterate(self, turn_inputs, kb_inputs, optimizer=None, grad_clip=None,
            is_training=True, method="GAN", mask=False):
    """Run one GAN training iteration over a dialogue batch.

    Iterates the whole multi-agent model (student S, teachers TB/TE, and
    discriminators B/E) rather than a single sub-model: first updates both
    discriminators on teacher-vs-student token distributions, then updates
    the student generator against them.

    Returns (metrics_list_S, metrics_list_G, metrics_list_DB,
    metrics_list_DE), one Pack per dialogue turn.
    """
    if isinstance(optimizer, tuple):
        optimizerG, optimizerDB, optimizerDE = optimizer
    # Clear all memory before the beginning of a new batch computation.
    for name, model in self.named_children():
        if name.startswith("model_"):
            model.reset_memory()
            model.load_kb_memory(kb_inputs)
    # Per-turn metrics for the whole (multi-agent) model.
    metrics_list_S, metrics_list_TB, metrics_list_TE = [], [], []
    metrics_list_G, metrics_list_DB, metrics_list_DE = [], [], []
    mask_list_S, length_list = [], []
    # Cumulative losses over the batch.
    total_loss_DB, total_loss_DE, total_loss_G = 0, 0, 0
    # Per-turn per-agent losses, cleared at the end of each turn.
    loss = Pack()
    # Per-turn kb masks for the three sub-models.
    kd_masks = Pack()
    # Compare evaluation metric (bleu/f1score) among models.
    if method in ('1-3', 'GAN'):
        # TODO complete
        bleu_ENS_gt_S, bleu_ENS_gt_TB, f1score_ENS_gt_TE = True, True, True
    else:
        # Compute bleu_S_gt_TB per batch, used by the following training batch.
        res_bleu = self.compare_metric(generator_1=self.generator_S,
                                       generator_2=self.generator_TB,
                                       turn_inputs=turn_inputs,
                                       kb_inputs=kb_inputs,
                                       type='bleu',
                                       data_name=self.data_name)
        if isinstance(res_bleu, tuple):
            bleu_S_gt_TB, bleu_S_gt_TB_str = res_bleu
        else:
            assert isinstance(res_bleu, bool)
            bleu_S_gt_TB, bleu_S_gt_TB_str = res_bleu, ''
        if self.model_TE is not None:
            res_f1score = self.compare_metric(generator_1=self.generator_S,
                                              generator_2=self.generator_TE,
                                              turn_inputs=turn_inputs,
                                              kb_inputs=kb_inputs,
                                              type='f1score',
                                              data_name=self.data_name)
            if isinstance(res_f1score, tuple):
                f1score_S_gt_TE, f1score_S_gt_TE_str = res_f1score
            else:
                assert isinstance(res_f1score, bool)
                f1score_S_gt_TE, f1score_S_gt_TE_str = res_f1score, ''
    """ update discriminator """
    # Clear all memory again: the generator comparison above accumulated state.
    for name, model in self.named_children():
        if name.startswith("model_"):
            model.reset_memory()
            model.load_kb_memory(kb_inputs)
    # Iterate over the turns of the dialogue batch.
    for i, inputs in enumerate(turn_inputs):
        for name, model in self.named_children():
            if name.startswith("model_"):
                if model.use_gpu:
                    inputs = inputs.cuda()
                src, src_lengths = inputs.src
                tgt, tgt_lengths = inputs.tgt
                task_label = inputs.task
                gold_entity = inputs.gold_entity
                ptr_index, ptr_lengths = inputs.ptr_index
                kb_index, kb_index_lengths = inputs.kb_index
                enc_inputs = src[:, 1:
                                 -1], src_lengths - 2  # filter <bos> <eos>
                dec_inputs = tgt[:, :-1], tgt_lengths - 1  # filter <eos>
                target = tgt[:, 1:]  # filter <bos>
                target_mask = sequence_mask(tgt_lengths - 1)
                kd_mask = sequence_kd_mask(tgt_lengths - 1, target, name,
                                           self.ent_idx, self.nen_idx)
                outputs = model.forward(enc_inputs, dec_inputs)
                metrics = model.collect_metrics(outputs, target, ptr_index,
                                                kb_index)
                if name == "model_S":
                    metrics_list_S.append(metrics)
                elif name == "model_TB":
                    metrics_list_TB.append(metrics)
                else:
                    metrics_list_TE.append(metrics)
                kd_masks[name] = kd_mask if mask else target_mask
                loss[name] = metrics
                model.update_memory(
                    dialog_state_memory=outputs.dialog_state_memory,
                    kb_state_memory=outputs.kb_state_memory)
        # Store the data each discriminator update needs for this turn.
        if self.model_TE is not None:
            kd_mask_e = kd_masks.model_TE
            kd_mask_s = kd_masks.model_S
            kd_mask_b = kd_masks.model_TB
            mask_list_S.append(kd_mask_s)
            length_list.append(tgt_lengths - 1)
            assert False not in (kd_mask_b == kd_mask_e)
            errD_B = self.discriminator_update(netD=self.discriminator_B,
                                               real_data=loss.model_TB.prob,
                                               fake_data=loss.model_S.prob,
                                               lengths=tgt_lengths - 1,
                                               mask=kd_mask_b)
            errD_E = self.discriminator_update(netD=self.discriminator_E,
                                               real_data=loss.model_TE.prob,
                                               fake_data=loss.model_S.prob,
                                               lengths=tgt_lengths - 1,
                                               mask=kd_mask_e)
            # Collect the discriminators' total loss.
            metrics_DB = Pack(num_samples=metrics.num_samples)
            metrics_DE = Pack(num_samples=metrics.num_samples)
            metrics_DB.add(loss=errD_B, logits=0.0, prob=0.0)
            metrics_DE.add(loss=errD_E, logits=0.0, prob=0.0)
            metrics_list_DB.append(metrics_DB)
            metrics_list_DE.append(metrics_DE)
            # Accumulate over the batch.
            total_loss_DB = total_loss_DB + errD_B
            total_loss_DE = total_loss_DE + errD_E
        loss.clear()
        kd_masks.clear()
    # check loss
    if torch.isnan(total_loss_DB) or torch.isnan(total_loss_DE):
        raise ValueError("NAN loss encountered!")
    # Compute and apply discriminator gradients.
    if is_training:
        assert not None in (optimizerDB, optimizerDE)
        optimizerDB.zero_grad()
        optimizerDE.zero_grad()
        total_loss_DB.backward()
        total_loss_DE.backward()
        if grad_clip is not None and grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(
                parameters=self.discriminator_B.parameters(),
                max_norm=grad_clip)
            torch.nn.utils.clip_grad_norm_(
                parameters=self.discriminator_E.parameters(),
                max_norm=grad_clip)
        optimizerDB.step()
        optimizerDE.step()
    """ update generator """
    # Replay the stored per-turn student data against the (updated)
    # discriminators.
    n_turn = len(metrics_list_S)
    assert n_turn == len(turn_inputs) == len(mask_list_S)
    for i in range(n_turn):
        errG, errG_B, errG_E, nll = self.generator_update(
            netG=self.model_S,
            netDB=self.discriminator_B,
            netDE=self.discriminator_E,
            fake_data=metrics_list_S[i].prob,
            length=length_list[i],
            mask=mask_list_S[i],
            nll=metrics_list_S[i].loss,
            lambda_g=self.lambda_g)
        # Collect the generator's total loss.
        metrics_G = Pack(num_samples=metrics_list_S[i].num_samples)
        metrics_G.add(loss=errG,
                      loss_gb=errG_B,
                      loss_ge=errG_E,
                      loss_nll=nll,
                      logits=0.0,
                      prob=0.0)
        metrics_list_G.append(metrics_G)
        # Accumulate over the batch.
        total_loss_G += errG
    # check loss
    if torch.isnan(total_loss_G):
        raise ValueError("NAN loss encountered!")
    # Compute and apply generator gradients.
    if is_training:
        assert optimizerG is not None
        optimizerG.zero_grad()
        total_loss_G.backward()
        if grad_clip is not None and grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(
                parameters=self.model_S.parameters(), max_norm=grad_clip)
        optimizerG.step()
    return metrics_list_S, metrics_list_G, metrics_list_DB, metrics_list_DE
def decode(self, input, state, is_training=False):
    """
    Run one decoding time step, optionally mixing in a copy distribution.

    Args:
        input: token ids for this step, shape (batch_size,) before embedding.
        state: decoder state carrying hidden, knowledge, attention memories,
            and (when self.copy) the source/cue inputs used for copying.
        is_training: kept for signature compatibility with the sibling
            decoders; this variant always returns log-probabilities.

    Returns:
        (log_prob, state, output): per-step log-probabilities over the
        vocabulary, the updated state, and a Pack with attention weights.
    """
    hidden = state.hidden
    rnn_input_list = []
    cue_input_list = []
    out_input_list = []
    output = Pack()

    if self.embedder is not None:
        input = self.embedder(input)

    # shape: (batch_size, 1, input_size)
    input = input.unsqueeze(1)
    rnn_input_list.append(input)
    cue_input_list.append(state.knowledge)

    if self.feature_size is not None:
        feature = state.feature.unsqueeze(1)
        rnn_input_list.append(feature)
        cue_input_list.append(feature)

    if self.attn_mode is not None:
        # attend over the source encoder outputs with the top-layer hidden
        weighted_context, attn = self.attention(
            query=hidden[-1].unsqueeze(1),
            memory=state.src_enc_outputs,
            mask=state.src_mask)
        rnn_input_list.append(weighted_context)
        cue_input_list.append(weighted_context)
        out_input_list.append(weighted_context)
        output.add(attn=attn)

    rnn_input = torch.cat(rnn_input_list, dim=-1)
    rnn_output, rnn_hidden = self.rnn(rnn_input, hidden)

    cue_input = torch.cat(cue_input_list, dim=-1)
    cue_output, cue_hidden = self.cue_rnn(cue_input, hidden)

    # fuse the utterance-path and knowledge-path hidden states
    h_y = self.tanh(self.fc1(rnn_hidden))
    h_cue = self.tanh(self.fc2(cue_hidden))
    if self.concat:
        new_hidden = self.fc3(torch.cat([h_y, h_cue], dim=-1))
    else:
        # gated combination: k in (0, 1) weights the two paths
        k = self.sigmoid(self.fc3(torch.cat([h_y, h_cue], dim=-1)))
        new_hidden = k * h_y + (1 - k) * h_cue
    out_input_list.append(new_hidden[-1].unsqueeze(1))

    out_input = torch.cat(out_input_list, dim=-1)
    prob = self.output_layer(out_input)

    if self.copy:
        batch_size, sent_num, sent, _ = state.cue_enc_outputs.size()
        # word-level attention over every knowledge sentence
        # NOTE(review): query uses .repeat(sent_num, 1, 1) on the batch dim;
        # confirm this matches the (batch_size * sent_num, ...) memory layout.
        _, knowledge_attn = self.attention(
            query=hidden[-1].unsqueeze(1).repeat(sent_num, 1, 1),
            memory=state.cue_enc_outputs.view(batch_size * sent_num, sent, -1),
            mask=state.cue_mask.view(batch_size * sent_num, -1))
        # scale word-level attention by sentence-level cue attention
        knowledge_attn = state.cue_attn.unsqueeze(
            2) * knowledge_attn.squeeze(1).view(batch_size, sent_num, -1)
        knowledge_attn = knowledge_attn.view(batch_size, 1, -1)
        output.add(knowledge_attn=knowledge_attn)
        # mixture gate over (generate, copy-from-source, copy-from-knowledge)
        # BUG FIX: was new_hidden[-1].unsqeeze(1) (typo), which raised
        # AttributeError whenever the copy branch executed.
        p = F.softmax(self.fc4(
            torch.cat([
                input, new_hidden[-1].unsqueeze(1), weighted_context,
                state.knowledge
            ], dim=-1)), dim=-1)
        output.add(p=p)
        p = p.split(1, dim=2)
        # scatter the copy probabilities back onto vocabulary positions
        prob = (p[0] * prob).scatter_add(2, state.src_inputs.unsqueeze(1),
                                         p[1] * attn)
        prob = prob.scatter_add(2, state.cue_inputs.view(batch_size, 1, -1),
                                p[2] * knowledge_attn)

    # epsilon guards log(0)
    log_prob = torch.log(prob + 1e-10)
    state.hidden = new_hidden
    return log_prob, state, output
def decode(self, input, state, is_training=False
           ):  # executed once per time step; note that batch_size here means the
               # number of effective samples, i.e. samples with no padding at the
               # current step
    """
    Run one decoding time step for the two-stage model.

    state.task_id selects the attention layout:
      - task_id == 1 (stage 2): single attention over the encoder outputs.
      - task_id == 0 (stage 1): attention over each of the similar queries
        separately, then mean-pooled.

    Returns (out_input, state, output) when is_training, else
    (log_prob, state, output).
    """
    # hidden: (num_layers, batch_size, 2 * rnn_hidden_size)
    hidden = state.hidden
    task_id = state.task_id
    rnn_input_list = []
    cue_input_list = []
    out_input_list = []  # collected for the decoder output layer
    output = Pack()
    if self.embedder is not None:
        input = self.embedder(input)  # (batch_size, input_size)
    input = input.unsqueeze(1)  # (batch_size, 1, input_size)
    rnn_input_list.append(input)
    # persona = state.cue_enc_outputs
    # persona: (batch_size, 1, 2*rnn_hidden_size) — the attention-weighted persona context
    if self.feature_size is not None:
        feature = state.feature.unsqueeze(1)
        rnn_input_list.append(feature)
        cue_input_list.append(feature)
    # attend over the encoder hidden states
    if self.attn_mode is not None:
        # NOTE(review): if task_id is neither 0 nor 1, weighted_context/attn
        # are never bound and the appends below would raise — confirm callers
        # only pass task_id in {0, 1}.
        # stage 2
        if task_id == 1:
            attn_memory = state.attn_memory  # (batch_size, max_len-2, 2*rnn_hidden_size)
            attn_mask = state.attn_mask
            query = hidden[-1].unsqueeze(
                1)  # (batch_size, 1, 2*rnn_hidden_size)
            weighted_context, attn = self.attention(
                query=query, memory=attn_memory, mask=attn_mask
            )  # attn_mask (batch_size, num_enc_inputs); weighted_context (batch_size, 1, 2*rnn_hidden_size)
        # stage 1
        elif task_id == 0:
            ''' attend over each of the similar queries separately '''
            attn_memory = state.attn_memory  # (batch_size, sent_num, max_len-2, 2*rnn_hidden_size)
            batch_size, sent_num, sent_len = attn_memory.size(
                0), attn_memory.size(1), attn_memory.size(2)
            attn_memory = attn_memory.view(
                batch_size * sent_num, sent_len,
                -1)  # (batch_size*sent_num, max_len-2, 2*rnn_hidden_size)
            attn_mask = state.attn_mask.view(
                batch_size * sent_num, -1
            )  # attn_mask (batch_size*sent_num, max_len-2); padded positions become 1, others 0
            query = hidden[-1].unsqueeze(1).repeat(1, sent_num, 1).view(
                batch_size * sent_num, 1,
                -1)  # (batch_size*sent_num, 1, 2*rnn_hidden_size)
            weighted_context, attn = self.attention(
                query=query, memory=attn_memory, mask=attn_mask
            )  # weighted_context (batch_size*sent_num, 1, 2*rnn_hidden_size)
            # mean-pool the per-query contexts back to one context per sample
            weighted_context = torch.mean(
                weighted_context.squeeze(1).view(batch_size, sent_num, -1),
                dim=1).unsqueeze(
                    1)  # weighted_context (batch_size, 1, 2*rnn_hidden_size)
        rnn_input_list.append(weighted_context)
        cue_input_list.append(weighted_context)
        out_input_list.append(weighted_context)
        output.add(attn=attn)
        ''' attend over the persona '''
        cue_attn_memory = state.cue_enc_outputs  # (batch_size, max_len-2, 2*rnn_hidden_size)
        cue_attn_mask = state.cue_attn_mask  # (batch_size, max_len-2)
        cue_query = hidden[-1].unsqueeze(
            1)  # (batch_size, 1, 2*rnn_hidden_size)
        cue_weighted_context, cue_attn = self.per_word_attention(
            query=cue_query, memory=cue_attn_memory, mask=cue_attn_mask)
        # cue_weighted_context (batch_size, 1, 2*rnn_hidden_size)
        # cue_attn (batch_size, 1, memory_size)
        cue_input_list.append(cue_weighted_context)
        # out_input_list.append(cue_weighted_context)
        output.add(cue_attn=cue_attn)
    rnn_input = torch.cat(
        rnn_input_list, dim=-1
    )  # rnn_input (batch_size, 1, input_size + 2*rnn_hidden_size + 2*rnn_hidden_size)
    rnn_output, rnn_hidden = self.rnn(
        rnn_input,
        hidden)  # rnn_hidden (num_layers, batch_size, 2*rnn_hidden_size)
    cue_input = torch.cat(cue_input_list,
                          dim=-1)  # (batch_size, 1, 4*rnn_hidden_size)
    cue_output, cue_hidden = self.cue_rnn(
        cue_input, hidden)  # cue_hidden (1, batch_size, 2*rnn_hidden_size)
    # fuse the utterance path and the persona path
    h_y = self.tanh(self.fc1(rnn_hidden))
    h_cue = self.tanh(self.fc2(cue_hidden))
    if self.concat:
        new_hidden = self.fc3(torch.cat(
            [h_y, h_cue], dim=-1))  # (1, batch_size, 2*rnn_hidden_size)
    else:
        # gated combination of the two paths
        k = self.sigmoid(self.fc3(torch.cat([h_y, h_cue], dim=-1)))
        new_hidden = k * h_y + (1 - k) * h_cue
    state.hidden = new_hidden  # (num_layers, batch_size, 2*rnn_hidden_size); updates hidden for the next time step
    out_input_list.append(
        new_hidden[-1].unsqueeze(1))  # (batch_size, 1, 2*rnn_hidden_size)
    out_input = torch.cat(
        out_input_list, dim=-1
    )  # (batch_size, 1, 4*rnn_hidden_size): fed to the decoder output layer, i.e. [c; h]
    if is_training:
        return out_input, state, output
        # out_input: input to the decoder output layer; state: decoder hidden
        # state; output: a Pack dict containing key "attn"
    else:
        # one time step
        # out_input (batch_size, 1, 4*rnn_hidden_size): fed to the decoder output layer, i.e. [c; h]
        log_prob = self.output_layer(out_input)
        return log_prob, state, output
def collect_metrics(self, outputs, target, epoch=-1):
    """
    Collect training metrics and assemble the total loss.

    Args:
        outputs: Pack produced by the forward pass (logits, prior/posterior
            attention, and optionally bow/dssm/pos/neg logits).
        target: LongTensor of reference token ids, (batch_size, max_len).
        epoch: current epoch; -1 means "not pretraining", so the NLL and KL
            terms are always included.

    Returns:
        (metrics, scores): a Pack of named metrics (including `loss`) and the
        per-sample negative NLL scores.
    """
    num_samples = target.size(0)
    metrics = Pack(num_samples=num_samples)
    loss = 0

    # test begin
    # nll = self.nll(torch.log(outputs.posterior_attn+1e-10), outputs.attn_index)
    # loss += nll
    # attn_acc = attn_accuracy(outputs.posterior_attn, outputs.attn_index)
    # metrics.add(attn_acc=attn_acc)
    # metrics.add(loss=loss)
    # return metrics
    # test end

    logits = outputs.logits
    # per-sample (negated) NLL, kept unreduced for reranking/reporting
    scores = -self.nll_loss(logits, target, reduction=False)
    nll_loss = self.nll_loss(logits, target)
    num_words = target.ne(self.padding_idx).sum().item()
    acc = accuracy(logits, target, padding_idx=self.padding_idx)
    metrics.add(nll=(nll_loss, num_words), acc=acc)

    if self.use_posterior:
        # KL between prior and (detached) posterior knowledge attention
        kl_loss = self.kl_loss(torch.log(outputs.prior_attn + 1e-10),
                               outputs.posterior_attn.detach())
        metrics.add(kl=kl_loss)
        if self.use_bow:
            # bag-of-words auxiliary loss: predict target words from knowledge
            bow_logits = outputs.bow_logits
            bow_labels = target[:, :-1]
            bow_logits = bow_logits.repeat(1, bow_labels.size(-1), 1)
            bow = self.nll_loss(bow_logits, bow_labels)
            loss += bow
            metrics.add(bow=bow)
        if self.use_dssm:
            # DSSM-style matching loss plus a pos/neg discrimination loss
            mse = self.mse_loss(outputs.dssm, outputs.reply_vec.detach())
            loss += mse
            metrics.add(mse=mse)
            pos_logits = outputs.pos_logits
            pos_target = torch.ones_like(pos_logits)
            neg_logits = outputs.neg_logits
            neg_target = torch.zeros_like(neg_logits)
            pos_loss = F.binary_cross_entropy_with_logits(
                pos_logits, pos_target, reduction='none')
            neg_loss = F.binary_cross_entropy_with_logits(
                neg_logits, neg_target, reduction='none')
            loss += (pos_loss + neg_loss).mean()
            metrics.add(pos_loss=pos_loss.mean(), neg_loss=neg_loss.mean())
        # after pretraining (or when no auxiliary task is enabled), add the
        # main NLL + KL objective; kl_loss is always bound here because this
        # branch is inside `if self.use_posterior`
        if epoch == -1 or epoch > self.pretrain_epoch or \
                (self.use_bow is not True and self.use_dssm is not True):
            loss += nll_loss
            loss += kl_loss
            if self.use_pg:
                # policy-gradient term rewarding low perplexity selections
                posterior_probs = outputs.posterior_attn.gather(
                    1, outputs.indexs.view(-1, 1))
                reward = -perplexity(logits, target, self.weight,
                                     self.padding_idx) * 100
                pg_loss = -(reward.detach() -
                            self.baseline) * posterior_probs.view(-1)
                pg_loss = pg_loss.mean()
                loss += pg_loss
                metrics.add(pg_loss=pg_loss, reward=reward.mean())
            if 'attn_index' in outputs:
                attn_acc = attn_accuracy(outputs.posterior_attn,
                                         outputs.attn_index)
                metrics.add(attn_acc=attn_acc)
    else:
        loss += nll_loss

    metrics.add(loss=loss)
    return metrics, scores
def reward_fn1(self, preds, targets, gold_ents, ptr_index, task_label):
    """
    reward_fn1: general reward.

    Computes a per-sample reward from the sentence-level BLEU between the
    decoded prediction and the reference, and also reports entity F1.

    Args:
        preds: LongTensor of predicted token ids, (batch_size, max_len).
        targets: LongTensor of reference token ids, (batch_size, max_len).
        gold_ents: per-sample gold entity lists used for entity F1.
        ptr_index: unused here; kept for signature compatibility with the
            other reward functions.
        task_label: unused here; kept for signature compatibility.

    Returns:
        (reward, bleu_score, report_f1):
            reward: FloatTensor (batch_size, 1), currently alpha1 * BLEU.
            bleu_score: FloatTensor (batch_size,).
            report_f1: FloatTensor of F1 for samples that have gold entities
                (a single 0.0 when none do).
    """
    # reward weights; alpha2 (entity-F1 weight) is currently unused because
    # the compound reward at the bottom is disabled.
    alpha1 = 1.0
    alpha2 = 0.3

    pred_text = self.tgt_field.denumericalize(preds)
    tgt_text = self.tgt_field.denumericalize(targets)

    batch_size = targets.size(0)
    batch_kb_inputs = self.kbs[:batch_size, :, :]
    kb_plain = self.kb_field.denumericalize(batch_kb_inputs)

    result = Pack()
    result.add(pred_text=pred_text, tgt_text=tgt_text,
               gold_ents=gold_ents, kb_plain=kb_plain)
    result_list = result.flatten()

    # BLEU reward: mean of smoothed BLEU-1 and BLEU-2 per sample.
    # Loop-invariant smoothing function hoisted out of the loop.
    smoothing = SmoothingFunction().method7
    bleu_score = []
    for res in result_list:
        hyp_toks = res.pred_text.split()
        ref_toks = res.tgt_text.split()
        # sentence_bleu can raise on degenerate hypotheses; fall back to 0.
        # Narrowed from bare `except:` so KeyboardInterrupt/SystemExit still
        # propagate.
        try:
            bleu_1 = sentence_bleu(references=[ref_toks],
                                   hypothesis=hyp_toks,
                                   smoothing_function=smoothing,
                                   weights=[1, 0, 0, 0])
        except Exception:
            bleu_1 = 0
        try:
            bleu_2 = sentence_bleu(references=[ref_toks],
                                   hypothesis=hyp_toks,
                                   smoothing_function=smoothing,
                                   weights=[0.5, 0.5, 0, 0])
        except Exception:
            bleu_2 = 0
        bleu = (bleu_1 + bleu_2) / 2
        bleu_score.append(bleu)
    bleu_score = torch.tensor(bleu_score, dtype=torch.float)

    # entity F1 reward
    f1_score = []
    report_f1 = []
    for res in result_list:
        if len(res.gold_ents) == 0:
            # no gold entities: counts as perfect for the reward but is
            # excluded from the reported F1
            f1_pred = 1.0
        else:
            # TODO: change the way
            #gold_entity = ' '.join(res.gold_ents).replace('_', ' ').split()
            #pred_sent = res.pred_text.replace('_', ' ')
            gold_entity = res.gold_ents
            pred_sent = res.pred_text
            f1_pred, _ = compute_prf(gold_entity, pred_sent,
                                     global_entity_list=[],
                                     kb_plain=res.kb_plain)
            report_f1.append(f1_pred)
        f1_score.append(f1_pred)
    if len(report_f1) == 0:
        report_f1.append(0.0)
    f1_score = torch.tensor(f1_score, dtype=torch.float)
    report_f1 = torch.tensor(report_f1, dtype=torch.float)

    if self.use_gpu:
        bleu_score = bleu_score.cuda()
        f1_score = f1_score.cuda()
        report_f1 = report_f1.cuda()

    # compound reward (entity-F1 term currently disabled):
    #reward = alpha1 * bleu_score.unsqueeze(-1) + alpha2 * f1_score.unsqueeze(-1)
    reward = alpha1 * bleu_score.unsqueeze(-1)

    return reward, bleu_score, report_f1
def encode(self, inputs, hidden=None, is_training=False):
    """
    Encode one batch and build the decoder's initial state.

    Args:
        inputs: one batch, {'src': ..., 'tgt': ..., 'cue': ...}; each field is
            a (token-id tensor, length tensor) pair.
        hidden: optional initial encoder hidden state.
        is_training: enables Gumbel-softmax knowledge sampling when no
            posterior is used.

    Returns:
        (outputs, dec_init_state): a Pack of attention/auxiliary outputs and
        the initialized decoder state.
    """
    outputs = Pack()
    # Chained assignment: enc_inputs is the (tokens, lengths) pair, and
    # `_, lengths` unpacks it. [:, 1:-1] and the -2 strip the bos/eos that
    # str2num in field.py adds around every sentence.
    enc_inputs = _, lengths = inputs.src[0][:, 1:-1], inputs.src[
        1] - 2
    enc_outputs, enc_hidden = self.encoder(enc_inputs, hidden)
    # enc_inputs: (batch_size, seq_len)
    # enc_outputs: (batch_size, seq_len, num_directions * hidden_size)
    # enc_hidden: (num_layers * num_directions, batch_size, hidden_size)
    #          -> (num_layers, batch_size, num_directions * hidden_size)
    # Here (1, batch_size, num_directions*hidden_size)[-1] gives
    # (batch_size, num_directions*hidden_size); due to the rnn_encoder.py
    # implementation, 2*hidden_size = config.hidden_size.

    if self.with_bridge:
        enc_hidden = self.bridge(enc_hidden)

    # knowledge
    batch_size, sent_num, sent = inputs.cue[0].size(
    )  # cue[0] for knowledge content, 3D
    # NOTE(review): this subtracts 2 from inputs.cue[1] in place, and the
    # attention masks below (inputs.cue[1].eq(0)) read the mutated lengths —
    # confirm this aliasing is intended.
    tmp_len = inputs.cue[1]
    tmp_len[tmp_len > 0] -= 2  # strip bos, eos
    cue_inputs = inputs.cue[0].view(-1, sent)[:, 1:-1], tmp_len.view(
        -1)  # 1:-1 strips bos, eos
    cue_enc_outputs, cue_enc_hidden = self.knowledge_encoder(
        cue_inputs, hidden)
    cue_outputs = cue_enc_hidden[-1].view(
        batch_size, sent_num, -1)
    # cue_enc_hidden[-1] is the representation of each knowledge sentence;
    # cue has one more dimension than src/tgt.

    # Attention: prior knowledge selection conditioned on the source only
    weighted_cue, cue_attn = self.prior_attention(
        query=enc_hidden[-1].unsqueeze(1),
        memory=cue_outputs,
        mask=inputs.cue[1].eq(0))
    cue_attn = cue_attn.squeeze(1)
    outputs.add(prior_attn=cue_attn)
    indexs = cue_attn.max(dim=1)[1]  # hard attention: take the argmax
    if self.use_gs:
        knowledge = cue_outputs.gather(1, \
            indexs.view(-1, 1, 1).repeat(1, 1, cue_outputs.size(-1)))
    else:
        knowledge = weighted_cue

    if self.use_posterior:  # p(k|y), not p(k|x,y)
        tgt_enc_inputs = inputs.tgt[0][:, 1:-1], inputs.tgt[1] - 2
        _, tgt_enc_hidden = self.knowledge_encoder(tgt_enc_inputs, hidden)
        posterior_weighted_cue, posterior_attn = self.posterior_attention(
            # P(z|u,r)
            # query=torch.cat([dec_init_hidden[-1], tgt_enc_hidden[-1]], dim=-1).unsqueeze(1)
            # P(z|r)
            query=tgt_enc_hidden[-1].unsqueeze(1),
            memory=cue_outputs,
            mask=inputs.cue[1].eq(0))
        posterior_attn = posterior_attn.squeeze(1)
        outputs.add(posterior_attn=posterior_attn)
        # Gumbel Softmax
        if self.use_gs:
            gumbel_attn = F.gumbel_softmax(torch.log(posterior_attn + 1e-10),
                                           0.1, hard=True)  # avoid log(0)
            outputs.add(gumbel_attn=gumbel_attn)
            knowledge = torch.bmm(gumbel_attn.unsqueeze(1), cue_outputs)
            indexs = gumbel_attn.max(-1)[1]
        else:
            knowledge = posterior_weighted_cue
            indexs = posterior_attn.max(dim=1)[1]

        if self.use_bow:
            bow_logits = self.bow_output_layer(knowledge)
            outputs.add(bow_logits=bow_logits)
        if self.use_dssm:
            dssm_knowledge = self.dssm_project(knowledge)
            outputs.add(dssm=dssm_knowledge)
            outputs.add(reply_vec=tgt_enc_hidden[-1])
            # neg sample: shift each sample's reply by one within the batch
            neg_idx = torch.arange(enc_inputs[1].size(0)).type_as(
                enc_inputs[1])
            neg_idx = (neg_idx + 1) % neg_idx.size(0)
            neg_tgt_enc_inputs = tgt_enc_inputs[0][
                neg_idx], tgt_enc_inputs[1][neg_idx]
            _, neg_tgt_enc_hidden = self.knowledge_encoder(
                neg_tgt_enc_inputs, hidden)
            pos_logits = (enc_hidden[-1] * tgt_enc_hidden[-1]).sum(dim=-1)
            neg_logits = (enc_hidden[-1] *
                          neg_tgt_enc_hidden[-1]).sum(dim=-1)
            outputs.add(pos_logits=pos_logits, neg_logits=neg_logits)
    elif is_training:
        if self.use_gs:
            gumbel_attn = F.gumbel_softmax(torch.log(cue_attn + 1e-10), 0.1,
                                           hard=True)
            knowledge = torch.bmm(gumbel_attn.unsqueeze(1), cue_outputs)
            indexs = gumbel_attn.max(-1)[1]
        else:
            knowledge = weighted_cue

    outputs.add(indexs=indexs)
    if 'index' in inputs.keys():
        outputs.add(attn_index=inputs.index)

    if self.use_kd:
        knowledge = self.knowledge_dropout(knowledge)

    if self.weight_control:
        # re-weight the knowledge representation by its similarity to
        # enc_hidden[-1]
        weights = (enc_hidden[-1] * knowledge.squeeze(1)).sum(dim=-1)
        weights = self.sigmoid(weights)
        # norm in batch
        # weights = weights / weights.mean().item()
        outputs.add(weights=weights)
        knowledge = knowledge * weights.view(-1, 1, 1).repeat(
            1, 1, knowledge.size(-1))

    dec_init_state = self.decoder.initialize_state(
        hidden=enc_hidden,
        attn_memory=enc_outputs if self.attn_mode else None,
        memory_lengths=lengths if self.attn_mode else None,
        knowledge=knowledge)
    return outputs, dec_init_state
def decode(self, input, state, is_training=False):
    """
    Run a single decoding step with history and fact attention.

    Args:
        input: token ids for this step, shape (batch_size,) before embedding.
        state: decoder state with hidden, attention memories/masks, and the
            hist/fact id tensors used for distribution conversion.
        is_training: when True, return the raw readout input instead of
            log-probabilities.

    Returns:
        (out_input, state, output) during training, otherwise
        (log_prob, state, output) where log_prob mixes the vocabulary, fact,
        and history distributions.
    """
    hidden = state.hidden
    output = Pack()

    if self.embedder is not None:
        input = self.embedder(input)
    # (batch_size, input_size) -> (batch_size, 1, input_size)
    input = input.unsqueeze(1)

    step_inputs = [input]
    readout_inputs = [input]

    if self.attn_mode is not None:
        # (batch_size, 1, hidden_size)
        query = hidden[-1].unsqueeze(1)
        # attend over the dialogue history
        weighted_hist, attn_h = self.hist_attention(query=query,
                                                    memory=state.attn_hist,
                                                    mask=state.hist_mask)
        # attend over the facts
        weighted_fact, attn_f = self.fact_attention(query=query,
                                                    memory=state.attn_fact,
                                                    mask=state.fact_mask)
        step_inputs += [weighted_hist, weighted_fact]
        readout_inputs += [weighted_hist, weighted_fact]
        output.add(attn_h=attn_h)
        output.add(attn_f=attn_f)

    rnn_output, new_hidden = self.rnn(torch.cat(step_inputs, dim=-1), hidden)
    readout_inputs.append(rnn_output)
    # cat (fact_hidden, hist_hidden, hidden, x): (batch_size, 1, out_input_size)
    out_input = torch.cat(readout_inputs, dim=-1)
    state.hidden = new_hidden

    if is_training:
        return out_input, state, output

    # mixture weights over (vocab, fact, history)
    p_mode = self.ff(out_input)
    vocab_dist = self.output_layer(out_input) * p_mode[:, :, 0].unsqueeze(2)
    fact_dist = output.attn_f * p_mode[:, :, 1].unsqueeze(2)
    hist_dist = output.attn_h * p_mode[:, :, 2].unsqueeze(2)
    # project the attention distributions onto vocabulary positions
    mixed = convert_dist(hist_dist, state.hist, vocab_dist)
    mixed = convert_dist(fact_dist, state.fact, mixed)
    # epsilon guards log(0)
    log_prob = torch.log(mixed + 1e-10)
    return log_prob, state, output