def collect_metrics(self, outputs):
    """Collect training metrics for a positive/negative logit pair.

    Returns a Pack carrying the combined BCE loss, per-side accuracies,
    the mean probability margin, and the batch size.
    """
    pos_logits = outputs.pos_logits
    neg_logits = outputs.neg_logits
    # BCE against all-ones / all-zeros targets, kept element-wise so the
    # two sides can be summed before the final mean reduction.
    pos_loss = F.binary_cross_entropy_with_logits(
        pos_logits, torch.ones_like(pos_logits), reduction='none')
    neg_loss = F.binary_cross_entropy_with_logits(
        neg_logits, torch.zeros_like(neg_logits), reduction='none')
    loss = (pos_loss + neg_loss).mean()

    pos_prob = torch.sigmoid(pos_logits)
    neg_prob = torch.sigmoid(neg_logits)
    # A positive pair is "correct" when its probability exceeds 0.5;
    # a negative pair when it falls below 0.5.
    pos_acc = pos_prob.gt(0.5).float().mean()
    neg_acc = neg_prob.lt(0.5).float().mean()
    margin = (pos_prob - neg_prob).mean()

    metrics = Pack(loss=loss, pos_acc=pos_acc, neg_acc=neg_acc, margin=margin)
    metrics.add(num_samples=pos_logits.size(0))
    return metrics
def decode(self, input, state, is_training=False):
    """ decode

    One decoding step of a knowledge-grounded decoder: a standard RNN and a
    knowledge ("cue") RNN run in parallel over the same previous hidden
    state, and their new hidden states are fused either by a linear layer
    over the concatenation (self.concat) or by a learned sigmoid gate.

    Returns (out_input, state, output) during training, or
    (log_prob, state, output) at inference time.
    """
    hidden = state.hidden
    rnn_input_list = []
    cue_input_list = []
    out_input_list = []
    output = Pack()

    if self.embedder is not None:
        input = self.embedder(input)

    # shape: (batch_size, 1, input_size)
    input = input.unsqueeze(1)
    rnn_input_list.append(input)
    # The cue RNN always consumes the knowledge representation.
    cue_input_list.append(state.knowledge)

    if self.feature_size is not None:
        feature = state.feature.unsqueeze(1)
        rnn_input_list.append(feature)
        cue_input_list.append(feature)

    if self.attn_mode is not None:
        attn_memory = state.attn_memory
        attn_mask = state.attn_mask
        # Query attention with the top-layer previous hidden state.
        query = hidden[-1].unsqueeze(1)
        weighted_context, attn = self.attention(query=query,
                                                memory=attn_memory,
                                                mask=attn_mask)
        rnn_input_list.append(weighted_context)
        cue_input_list.append(weighted_context)
        out_input_list.append(weighted_context)
        output.add(attn=attn)

    rnn_input = torch.cat(rnn_input_list, dim=-1)
    rnn_output, rnn_hidden = self.rnn(rnn_input, hidden)
    cue_input = torch.cat(cue_input_list, dim=-1)
    cue_output, cue_hidden = self.cue_rnn(cue_input, hidden)
    h_y = self.tanh(self.fc1(rnn_hidden))
    h_cue = self.tanh(self.fc2(cue_hidden))
    if self.concat:
        new_hidden = self.fc3(torch.cat([h_y, h_cue], dim=-1))
    else:
        # Gated fusion: k weighs the language hidden against the cue hidden.
        k = self.sigmoid(self.fc3(torch.cat([h_y, h_cue], dim=-1)))
        new_hidden = k * h_y + (1 - k) * h_cue
    # out_input_list.append(new_hidden.transpose(0, 1))  # bug fixed:
    # the output projection should see only the top layer, shaped (B, 1, H).
    out_input_list.append(new_hidden[-1].unsqueeze(1))
    out_input = torch.cat(out_input_list, dim=-1)
    state.hidden = new_hidden
    if is_training:
        return out_input, state, output
    else:
        log_prob = self.output_layer(out_input)
        return log_prob, state, output
def collate(data_list):
    """Assemble a list of samples into a single batch Pack.

    Numericalized fields are tensorized with list2tensor; raw-text fields
    are carried as plain python lists.
    """
    batch = Pack()
    # Fields that were numericalized and need tensor conversion.
    for key in ('num_src', 'num_tgt_input', 'tgt_output', 'tgt_emo', 'mask'):
        batch[key] = list2tensor([sample[key] for sample in data_list])
    # Raw-text fields stay as python lists (used for reporting/decoding).
    for key in ('raw_src', 'raw_tgt'):
        batch[key] = [sample[key] for sample in data_list]
    if 'id' in data_list[0].keys():
        batch['id'] = [sample['id'] for sample in data_list]
    if device >= 0:
        batch = batch.cuda(device=device)
    return batch
def collate(data_list):
    """Merge per-sample dicts into one batched Pack of tensors."""
    batch = Pack()
    sample_keys = data_list[0].keys()
    for key in sample_keys:
        values = [sample[key] for sample in data_list]
        batch[key] = list2tensor(values)
    if device >= 0:
        batch = batch.cuda(device=device)
    return batch
def collect_rl_metrics(self, sample_outputs, greedy_outputs, target,
                       gold_entity, entity_dir):
    """ collect rl training metrics

    Self-critical RL metrics: the greedy decode acts as a baseline for the
    sampled decode, so the effective reward is (sample - greedy). The loss
    is the sequence log-probability of the sample weighted by that reward.

    Returns a Pack with loss, mean reward, greedy accuracy, and the
    BLEU/F1 scores computed by reward_fn on the greedy output.
    """
    num_samples = target.size(0)
    rl_metrics = Pack(num_samples=num_samples)
    loss = 0
    # log prob for sampling and greedy generation
    logits = sample_outputs.logits
    sample = sample_outputs.pred_word
    greedy_logits = greedy_outputs.logits
    greedy_sample = greedy_outputs.pred_word
    # cal reward: sampled reward minus the greedy (baseline) reward
    sample_reward, _, _ = self.reward_fn(sample, target, gold_entity, entity_dir)
    greedy_reward, bleu_score, f1_score = self.reward_fn(greedy_sample, target,
                                                         gold_entity, entity_dir)
    reward = sample_reward - greedy_reward
    # cal RL loss: per-token NLL of the sampled sequence, padding masked out
    sample_log_prob = self.nll_loss(logits, sample,
                                    mask=sample.ne(self.padding_idx),
                                    reduction=False,
                                    matrix=False)  # [batch_size, max_len]
    nll = sample_log_prob * reward.to(sample_log_prob.device)
    # Reduce with whatever reduction the NLL loss is configured for
    # (e.g. torch.mean / torch.sum), applied over per-sequence sums.
    nll = getattr(torch, self.nll_loss.reduction)(nll.sum(dim=-1))
    loss += nll
    # gen report
    rl_acc = accuracy(greedy_logits, target, padding_idx=self.padding_idx)
    if reward.dim() == 2:
        reward = reward.sum(dim=-1)
    rl_metrics.add(loss=loss,
                   reward=reward.mean(),
                   rl_acc=rl_acc,
                   bleu_score=bleu_score.mean(),
                   f1_score=f1_score.mean())
    return rl_metrics
def encode(self, inputs, hidden=None, is_training=False):
    """ encode

    Encodes the source and, when the neural topic model is enabled, builds
    a topic feature from topic-word embeddings that is handed to the
    decoder's initial state.
    """
    '''
    inputs: src, topic_src, topic_tgt, [tgt]
    '''
    outputs = Pack()
    # Chained assignment: enc_inputs is the (tokens, lengths) pair with
    # BOS/EOS stripped; lengths is also bound on its own for reuse below.
    enc_inputs = _, lengths = inputs.src[0][:, 1:-1], inputs.src[1] - 2
    enc_outputs, enc_hidden = self.encoder(enc_inputs, hidden)
    if self.with_bridge:
        enc_hidden = self.bridge(enc_hidden)
    guide_score = enc_hidden[-1]
    decoder_init_state = enc_hidden
    if self.use_ntm:
        bow_src = inputs.bow
        ntm_stat = self.ntm(bow_src)
        outputs.add(ntm_loss=ntm_stat['loss'])
        embedding_weight = self.topic_embedder.weight
        topic_word_logit = self.ntm.topics.get_topic_word_logit()
        EPS = 1e-12
        w = topic_word_logit.transpose(0, 1)  # V x K
        # L2-normalized embeddings and topic-word columns.
        nv = embedding_weight / (
            torch.norm(embedding_weight, dim=1, keepdim=True) + EPS)  # V x dim
        # NOTE(review): nw is computed but never used below — confirm intent.
        nw = w / (torch.norm(w, dim=0, keepdim=True) + EPS)  # V x K
        # Topic embedding matrix t: one dim-sized vector per topic (K x dim).
        t = nv.transpose(0, 1) @ w  # dim x K
        t = t.transpose(0, 1)
        topic_feature_input = []
        if 'S' in self.decoder_attention_channels:
            src_labels = F.one_hot(inputs.topic_src_label,
                                   num_classes=self.topic_num)  # B * K
            # One-hot @ t selects each sample's topic embedding.
            src_topics = src_labels.float() @ t
            topic_feature_input.append(src_topics)
        if 'T' in self.decoder_attention_channels:
            tgt_labels = F.one_hot(inputs.topic_tgt_label,
                                   num_classes=self.topic_num)
            tgt_topics = tgt_labels.float() @ t
            topic_feature_input.append(tgt_topics)
        topic_feature = self.t_to_feature(
            torch.cat(topic_feature_input, dim=-1))
    # NOTE(review): topic_feature is only assigned on the use_ntm path —
    # confirm this model is always constructed with use_ntm enabled.
    dec_init_state = self.decoder.initialize_state(
        hidden=decoder_init_state,
        attn_memory=enc_outputs,
        memory_lengths=lengths,
        guide_score=guide_score,
        topic_feature=topic_feature)
    return outputs, dec_init_state
def collate(data_list):
    """Collate a list of samples into one training batch.

    data_list has batch_size elements; each element is the dict produced
    by __getitem__ (keys such as src, tgt, cue). All values of one key
    are gathered across samples and tensorized together.
    """
    batch = Pack()
    for field in data_list[0].keys():
        # Gather this field from every sample, then convert to tensors.
        field_values = [sample[field] for sample in data_list]
        batch[field] = list2tensor(field_values)
    if device >= 0:
        batch = batch.cuda(device=device)
    return batch
def collate(data_list):
    """Turn a list of per-sample dicts into one batched Pack.

    Each element of data_list is the dict for a single sample; every key
    present in the first sample is batched across the whole list.
    """
    batch = Pack()
    keys = list(data_list[0].keys())
    for k in keys:
        batch[k] = list2tensor([d[k] for d in data_list])
    if device >= 0:
        batch = batch.cuda(device=device)
    return batch
def decode(self, input, state, is_training=False):
    """One decoding step of a memory-network (Mem2Seq-style) decoder.

    Runs a GRU step over the embedded input (plus attention context if
    enabled), then performs self.max_hops hops over the stored memory
    self.m_story to produce vocabulary scores.

    Returns (p_vocab, state, output) during training — pre-softmax scores —
    or (prob_v, state, output) at inference (log-softmax of p_vocab).
    """
    last_hidden = state.hidden
    rnn_input_list = []
    output = Pack()
    embed_q = self.C[0](input).unsqueeze(1)  # b * e --> b * 1 * e
    batch_size = input.size(0)
    rnn_input_list.append(embed_q)
    if self.attn_mode is not None:
        attn_memory = state.attn_memory
        attn_mask = state.attn_mask
        query = last_hidden[-1].unsqueeze(1)
        weighted_context, attn = self.attention(query=query,
                                                memory=attn_memory,
                                                mask=attn_mask)
        rnn_input_list.append(weighted_context)
        output.add(attn=attn)
    rnn_input = torch.cat(rnn_input_list, dim=-1)
    # NOTE(review): this rebinds `output`, discarding the Pack (and the attn
    # entry added above) — the value returned below is the raw GRU output.
    # Confirm whether callers expect the Pack here.
    output, new_hidden = self.gru(rnn_input, last_hidden)
    state.hidden = new_hidden
    u = [new_hidden[0].squeeze()]
    for hop in range(self.max_hops):
        m_A = self.m_story[hop]
        m_A = m_A[:batch_size]
        if (len(list(u[-1].size())) == 1):
            u[-1] = u[-1].unsqueeze(0)  # used for bsz = 1.
        u_temp = u[-1].unsqueeze(1).expand_as(m_A)
        # Attention scores of the query against memory slots.
        prob_lg = torch.sum(m_A * u_temp, 2)
        # NOTE(review): log_softmax values are used directly as weights for
        # the memory read below (not exponentiated) — verify this is intended.
        prob_p = self.log_softmax(prob_lg)
        m_C = self.m_story[hop + 1]
        m_C = m_C[:batch_size]
        prob = prob_p.unsqueeze(2).expand_as(m_C)
        o_k = torch.sum(m_C * prob, 1)
        if (hop == 0):
            # Vocabulary scores come from the first hop only.
            p_vocab = self.W1(torch.cat((u[0], o_k), 1))
            prob_v = self.log_softmax(p_vocab)
        u_k = u[-1] + o_k
        u.append(u_k)
    # Pointer scores from the final hop (currently unused by the returns).
    p_ptr = prob_lg
    if is_training:
        # p_ptr and p_vocab are pre-softmax scores, not probabilities.
        return p_vocab, state, output
    else:
        return prob_v, state, output
def forward(self, src_inputs, pos_tgt_inputs, neg_tgt_inputs,
            src_hidden=None, tgt_hidden=None):
    """Score a source against a positive and a negative target.

    Each sequence is embedded by its encoder's final top-layer hidden
    state; logits are dot products between source and target embeddings.
    """
    outputs = Pack()
    # [1][-1]: take the encoder's hidden-state tuple element, top layer.
    src_repr = self.src_encoder(src_inputs, src_hidden)[1][-1]
    if self.with_project:
        src_repr = self.project(src_repr)
    pos_repr = self.tgt_encoder(pos_tgt_inputs, tgt_hidden)[1][-1]
    neg_repr = self.tgt_encoder(neg_tgt_inputs, tgt_hidden)[1][-1]
    outputs.add(pos_logits=(src_repr * pos_repr).sum(dim=-1),
                neg_logits=(src_repr * neg_repr).sum(dim=-1))
    return outputs
def encode(self, inputs, hidden=None, is_training=False):
    """ encode

    Encodes the source and builds source/target topic-word memories from
    the neural topic model for the decoder's initial state.
    """
    '''
    inputs: src, topic_src, topic_tgt, [tgt]
    '''
    outputs = Pack()
    # Chained assignment: strip BOS/EOS from tokens and shrink lengths.
    enc_inputs = _, lengths = inputs.src[0][:, 1:-1], inputs.src[1] - 2
    enc_outputs, enc_hidden = self.encoder(enc_inputs, hidden)
    if self.with_bridge:
        enc_hidden = self.bridge(enc_hidden)
    guide_score = enc_hidden[-1]
    decoder_init_state = enc_hidden
    bow_src = inputs.bow
    ntm_stat = self.ntm(bow_src)
    outputs.add(ntm_loss=ntm_stat['loss'])
    # obtain topic words: top-k word indices per topic
    _, tw_indices = self.ntm.get_topics().topk(self.topic_k, dim=1)  # K * k
    src_labels = F.one_hot(inputs.topic_src_label,
                           num_classes=self.topic_num)  # B * K
    tgt_labels = F.one_hot(inputs.topic_tgt_label,
                           num_classes=self.topic_num)
    # One-hot @ index matrix selects each sample's topic-word id row;
    # the float round-trip is just to run the matmul, hence .long() after.
    src_words = src_labels.float() @ tw_indices.float()  # B * k
    src_words = src_words.detach().long()
    tgt_words = tgt_labels.float() @ tw_indices.float()
    tgt_words = tgt_words.detach().long()
    # only src topic word
    src_outputs = self.topic_layer(src_words)  # b * k * h
    # only tgt topic word
    tgt_outputs = self.topic_layer(tgt_words)  # b * k * h
    dec_init_state = self.decoder.initialize_state(
        hidden=decoder_init_state,
        attn_memory=enc_outputs,
        memory_lengths=lengths,
        guide_score=guide_score,
        src_memory=src_outputs,
        tgt_memory=tgt_outputs,
    )
    return outputs, dec_init_state
def encode(self, enc_inputs, hidden=None):
    """ encode

    Encodes one dialog turn and appends it to the persistent cross-turn
    dialog memory kept on self (state/history memories, history token
    index, attention masks), then builds the decoder's initial state with
    both the dialog memory and the per-batch KB memories.

    NOTE(review): this method mutates self.dialog_state_memory et al. as a
    side effect, so call order across turns matters.
    """
    outputs = Pack()
    enc_outputs, enc_hidden = self.encoder(enc_inputs, hidden)
    inputs, lengths = enc_inputs
    batch_size = enc_outputs.size(0)
    max_len = enc_outputs.size(1)
    # True where positions are padding (mask == 0).
    attn_mask = sequence_mask(lengths, max_len).eq(0)
    if self.with_bridge:
        enc_hidden = self.bridge(enc_hidden)
    # insert dialog memory
    if self.dialog_state_memory is None:
        # First turn: all memories must be unset together.
        assert self.dialog_history_memory is None
        assert self.history_index is None
        assert self.memory_masks is None
        self.dialog_state_memory = enc_outputs
        self.dialog_history_memory = enc_outputs
        self.history_index = inputs
        self.memory_masks = attn_mask
    else:
        # Later turns: truncate stored memories to the current batch size,
        # then append this turn's encodings along the time dimension.
        batch_state_memory = self.dialog_state_memory[:batch_size, :, :]
        self.dialog_state_memory = torch.cat([batch_state_memory,
                                              enc_outputs], dim=1)
        batch_history_memory = self.dialog_history_memory[:batch_size, :, :]
        self.dialog_history_memory = torch.cat([batch_history_memory,
                                                enc_outputs], dim=1)
        batch_history_index = self.history_index[:batch_size, :]
        self.history_index = torch.cat([batch_history_index, inputs], dim=-1)
        batch_memory_masks = self.memory_masks[:batch_size, :]
        self.memory_masks = torch.cat([batch_memory_masks, attn_mask], dim=-1)
    # Slice per-batch KB structures to the current batch size.
    batch_kb_inputs = self.kbs[:batch_size, :, :]
    batch_kb_state_memory = self.kb_state_memory[:batch_size, :, :]
    batch_kb_slot_memory = self.kb_slot_memory[:batch_size, :, :]
    batch_kb_slot_index = self.kb_slot_index[:batch_size, :]
    kb_mask = self.kb_mask[:batch_size, :]
    selector_mask = self.selector_mask[:batch_size, :]
    # create batched KB inputs
    kb_memory, selector = self.decoder.initialize_kb(kb_inputs=batch_kb_inputs,
                                                     enc_hidden=enc_hidden)
    # initialize decoder state
    dec_init_state = self.decoder.initialize_state(
        hidden=enc_hidden,
        state_memory=self.dialog_state_memory,
        history_memory=self.dialog_history_memory,
        kb_memory=kb_memory,
        kb_state_memory=batch_kb_state_memory,
        kb_slot_memory=batch_kb_slot_memory,
        history_index=self.history_index,
        kb_slot_index=batch_kb_slot_index,
        attn_mask=self.memory_masks,
        attn_kb_mask=kb_mask,
        selector=selector,
        selector_mask=selector_mask
    )
    return outputs, dec_init_state
def decode(self, input, state, is_training=False):
    """One decoding step with input-feeding attention.

    The previous step's attentional output (state.input_feed) is
    concatenated to the embedded input before the RNN step; the new
    attentional output is stored back for the next step.

    Returns (out_input, state, output) during training, or
    (log_prob, state, output) at inference time.
    """
    hidden = state.hidden
    input_feed = state.input_feed
    output = Pack()
    if self.embedder is not None:
        input = self.embedder(input)
    # shape: (batch_size, 1, input_size)
    input = input.unsqueeze(1)
    # Input feeding: previous attentional vector joins the embedding.
    # (Removed a redundant `input_feed = input_feed` self-assignment.)
    rnn_input = torch.cat([input, input_feed], dim=-1)
    rnn_output, new_hidden = self.rnn(rnn_input, hidden)
    attn_memory = state.attn_memory
    # Query with the top layer of the (LSTM) hidden state.
    query = new_hidden[0][-1].unsqueeze(1)
    weighted_context, attn = self.attention(query=query,
                                            memory=attn_memory,
                                            mask=state.mask.eq(0))
    # Project [query; context] down to the attentional output.
    final_output = self.fc1(torch.cat([query, weighted_context], dim=-1))
    output.add(attn=attn)
    state.hidden = list(new_hidden)
    state.input_feed = final_output
    out_input = final_output
    if is_training:
        return out_input, state, output
    else:
        log_prob = self.output_layer(out_input)
        return log_prob, state, output
def collect_metrics(self, outputs, target, ptr_index, kb_index):
    """ collect_metrics

    Computes the generation NLL and token accuracy. The gate loss and the
    selector loss are currently disabled (kept below as comments);
    loss_ptr is reported as a constant zero so downstream logging keeps a
    stable schema.
    """
    num_samples = target.size(0)
    metrics = Pack(num_samples=num_samples)
    loss = 0
    # loss for generation
    logits = outputs.logits
    nll = self.nll_loss(logits, target)
    loss += nll
    '''
    # loss for gate
    pad_zeros = torch.zeros([num_samples, 1], dtype=torch.long)
    if self.use_gpu:
        pad_zeros = pad_zeros.cuda()
    ptr_index = torch.cat([ptr_index, pad_zeros], dim=-1).float()
    gate_logits = outputs.gate_logits
    loss_gate = self.bce_loss(gate_logits, ptr_index)
    loss += loss_gate
    '''
    # loss for selector (disabled; kept for reference)
    # selector_target = kb_index.float()
    # selector_logits = outputs.selector_logits
    # selector_mask = outputs.selector_mask
    #
    # if selector_target.size(-1) < selector_logits.size(-1):
    #     pad_zeros = torch.zeros(size=(num_samples, selector_logits.size(-1)-selector_target.size(-1)),
    #                             dtype=torch.float)
    #     if self.use_gpu:
    #         pad_zeros = pad_zeros.cuda()
    #     selector_target = torch.cat([selector_target, pad_zeros], dim=-1)
    # loss_ptr = self.bce_loss(selector_logits, selector_target, mask=selector_mask)
    # Placeholder: pointer loss is a constant zero while the selector
    # branch above is disabled.
    loss_ptr = torch.tensor(0.0)
    if self.use_gpu:
        loss_ptr = loss_ptr.cuda()
    loss += loss_ptr
    acc = accuracy(logits, target, padding_idx=self.padding_idx)
    metrics.add(loss=loss, ptr=loss_ptr, acc=acc, logits=logits,
                prob=outputs.prob)
    return metrics
def decode(self, input, state, is_training=False):
    """Run a single attention-decoder step and update the decoder state.

    Returns (out_input, state, output) during training, or
    (log_prob, state, output) at inference time.
    """
    prev_hidden = state.hidden
    step_inputs = []       # pieces concatenated into the RNN input
    readout_inputs = []    # pieces concatenated into the output projection
    output = Pack()

    if self.embedder is not None:
        input = self.embedder(input)
    input = input.unsqueeze(1)  # (batch_size, 1, input_size)
    step_inputs.append(input)

    if self.feature_size is not None:
        step_inputs.append(state.feature.unsqueeze(1))

    if self.attn_mode is not None:
        # Attend with the top-layer previous hidden state as the query.
        query = prev_hidden[-1].unsqueeze(1)
        weighted_context, attn = self.attention(query=query,
                                                memory=state.attn_memory,
                                                mask=state.attn_mask)
        step_inputs.append(weighted_context)
        readout_inputs.append(weighted_context)
        output.add(attn=attn)

    rnn_output, new_hidden = self.rnn(torch.cat(step_inputs, dim=-1),
                                      prev_hidden)
    readout_inputs.append(rnn_output)
    out_input = torch.cat(readout_inputs, dim=-1)
    state.hidden = new_hidden

    if is_training:
        return out_input, state, output
    log_prob = self.output_layer(out_input)
    return log_prob, state, output
def generate(self, batch_iter, num_batches):
    """Greedy per-token emotion prediction over a batch iterator.

    For every batch, forwards the model, takes the argmax emotion class per
    target token, and packs (target tokens, predicted emotion labels, gold
    emotion labels) into flattened per-sample results.

    NOTE(review): num_batches is accepted but unused here — confirm caller
    contract.
    """
    self.model.eval()
    # Index-to-emotion label mapping; class ids 0/1/2.
    itoemo = ['NORM', 'POS', 'NEG']
    with torch.no_grad():
        results = []
        for inputs in batch_iter:
            enc_inputs = inputs
            dec_inputs = inputs.num_tgt_input
            enc_outputs = Pack()
            outputs = self.model.forward(enc_inputs, dec_inputs, hidden=None)
            outputs = outputs.logits
            # (values, indices) over the class dimension.
            preds = outputs.max(dim=2)
            # news_id = inputs.id
            tgt_raw = inputs.raw_tgt
            preds = preds[1].tolist()  # keep only the argmax indices
            temp_a_1 = []  # per-sample target tokens
            emo_b_1 = []   # per-sample predicted emotion labels
            temp = []      # per-sample gold emotion labels
            tgt_emo = inputs.tgt_emo[0].tolist()
            # a: raw target tokens, b: predicted class ids, c: gold class ids
            for a, b, c in zip(tgt_raw, preds, tgt_emo):
                # enc_outputs.add(preds=preds, scores=scores, emos=emos, target_emos=temp)
                # result_batch = enc_outputs.flatten()
                # results += result_batch
                a = a[1:]  # drop the leading (BOS) token
                temp_a = []
                emo_b = []
                emo_c = []
                for i, entity in enumerate(a):
                    temp_a.append(entity)
                    emo_b.append(itoemo[b[i]])
                    emo_c.append(itoemo[c[i]])
                # tgt_raw=tgt_raw[1:]
                assert len(temp_a) == len(emo_b)
                assert len(emo_c) == len(emo_b)
                temp_a_1.append([temp_a])  # pred tokens
                emo_b_1.append([emo_b])    # predicted emotions
                temp.append(emo_c)         # gold emotions
            # temp = []
            # tgt_emo = inputs.tgt_emo[0].tolist()
            # for item in tgt_emo:
            #     temp.append([itoemo[x] for x in item])
            # print(emo_b_1)
            # print(temp)
            if hasattr(inputs, 'id') and inputs.id is not None:
                enc_outputs.add(id=inputs['id'])
            enc_outputs.add(tgt=tgt_raw, preds=temp_a_1, emos=emo_b_1,
                            target_emos=temp)
            result_batch = enc_outputs.flatten()
            results += result_batch
    return results
def collect_metrics(self, outputs, target, epoch=-1):
    """ collect_metrics

    Computes generation NLL/accuracy and, when persona attention targets
    are available, a persona selection loss; the final loss mixes the two
    with fixed weights 0.7 (persona) and 0.3 (NLL).

    Returns (metrics, scores) where scores are per-sample negative NLLs.
    """
    num_samples = target.size(0)
    metrics = Pack(num_samples=num_samples)
    loss = 0

    # test begin — disabled attention-supervision experiment, kept for reference
    # nll = self.nll(torch.log(outputs.posterior_attn+1e-10), outputs.attn_index)
    # loss += nll
    # attn_acc = attn_accuracy(outputs.posterior_attn, outputs.attn_index)
    # metrics.add(attn_acc=attn_acc)
    # metrics.add(loss=loss)
    # return metrics
    # test end

    logits = outputs.logits
    # Unreduced NLL negated: higher score = more likely sequence.
    scores = -self.nll_loss(logits, target, reduction=False)
    nll_loss = self.nll_loss(logits, target)
    num_words = target.ne(self.padding_idx).sum().item()
    acc = accuracy(logits, target, padding_idx=self.padding_idx)
    metrics.add(nll=(nll_loss, num_words), acc=acc)
    # persona loss
    if 'attn_index' in outputs:
        attn_acc = attn_accuracy(outputs.cue_attn, outputs.attn_index)
        metrics.add(attn_acc=attn_acc)
        # cue_attn: (batch_size, sent_num); eps guards the log.
        per_logits = torch.log(outputs.cue_attn + self.eps)
        per_labels = outputs.attn_index  # (batch_size)
        use_per_loss = self.persona_loss(per_logits, per_labels)
        metrics.add(use_per_loss=use_per_loss)
        # Fixed mixing weights between persona selection and generation.
        loss += 0.7 * use_per_loss
        loss += 0.3 * nll_loss
    else:
        loss += nll_loss
    metrics.add(loss=loss)
    return metrics, scores
def collate(data_list):
    """Batch samples, converting the 'topic' field to bag-of-words vectors.

    Every key except 'topic' is tensorized normally; 'topic' holds
    (word index, frequency) pairs which are accumulated into a dense
    vector of size bow_vocab_size per sample.
    """
    batch = Pack()
    for key in data_list[0].keys():
        if key == 'topic':
            continue  # handled separately below
        batch[key] = list2tensor([sample[key] for sample in data_list])

    bow_vectors = []
    for sample in data_list:
        vec = torch.zeros(bow_vocab_size, dtype=torch.float)
        # Accumulate frequency mass at each word index.
        for word, freq in sample['topic']:
            vec[word] += freq
        bow_vectors.append(vec)
    batch['bow'] = torch.stack(bow_vectors)

    if device >= 0:
        batch = batch.cuda(device=device)
    return batch
def collate(data_list):
    """Batch samples and attach copy-mechanism vocabulary mappings.

    data_list: List[Dict], one dict per sample. After the standard
    tensorization, per-batch token<->index mappings are built from the
    whitespace-tokenized raw source strings.
    """
    batch = Pack()
    for field in data_list[0].keys():
        batch[field] = list2tensor([sample[field] for sample in data_list])
    if device >= 0:
        batch = batch.cuda(device=device)
    # Copy-mechanism preparation.
    tokenized_src = [sample['raw_src'].split() for sample in data_list]
    (token2idx, idx2token,
     batch_pos_idx_map, idx2idx_mapping) = build_copy_mapping(tokenized_src,
                                                              vocab)
    batch['token2idx'] = token2idx
    batch['idx2token'] = idx2token
    batch['batch_pos_idx_map'] = batch_pos_idx_map
    batch['idx2idx_mapping'] = idx2idx_mapping
    # NOTE(review): '???' looks like an unfinished placeholder — confirm
    # what downstream consumers expect under 'output'.
    batch['output'] = '???'
    return batch
def encode(self, inputs, hidden=None):
    """Encode the source sequence and build the decoder's initial state."""
    outputs = Pack()
    # Strip BOS/EOS from the padded batch and shrink lengths to match;
    # the chained assignment also keeps the (tokens, lengths) pair intact.
    enc_inputs = _, lengths = inputs.src[0][:, 1:-1], inputs.src[1] - 2
    enc_outputs, enc_hidden = self.encoder(enc_inputs, hidden)
    if self.with_bridge:
        enc_hidden = self.bridge(enc_hidden)
    # Attention memory is only supplied when an attention mode is set.
    attn_memory = enc_outputs if self.attn_mode else None
    memory_lengths = lengths if self.attn_mode else None
    dec_init_state = self.decoder.initialize_state(
        hidden=enc_hidden,
        attn_memory=attn_memory,
        memory_lengths=memory_lengths)
    return outputs, dec_init_state
def create_turn_batch(data_list):
    """Convert per-turn data dicts into a list of batched Packs.

    Tensorizable fields (src/tgt/ptr_index/kb_index) go through
    list2tensor; every other field is carried over unchanged.
    """
    # Membership test against a tuple hoisted out of the loop; the previous
    # `[x for x in data_dict[key]]` copy is replaced with list(...).
    tensor_keys = ('src', 'tgt', 'ptr_index', 'kb_index')
    turn_batches = []
    for data_dict in data_list:
        batch = Pack()
        for key in data_dict.keys():
            if key in tensor_keys:
                batch[key] = list2tensor(list(data_dict[key]))
            else:
                batch[key] = data_dict[key]
        turn_batches.append(batch)
    return turn_batches
def collect_metrics(self, outputs, target, emo_target):
    """ collect_metrics

    Computes a masked copy-distribution NLL over target tokens plus a
    masked emotion-classification NLL (emotion targets are shifted by one
    position via the shortened mask). Only copy accuracy is reported.
    """
    num_samples = target[0].size(0)
    num_words = target[1].sum().item()
    metrics = Pack(num_samples=num_samples)
    target_len = target[1]
    mask = sequence_mask(target_len)
    mask = mask.float()
    # logits = outputs.logits
    # nll = self.nll_loss(logits, target)
    out_copy = outputs.out_copy
    # out_copy: batch x max_len x src — probability over source positions.
    target_loss = out_copy.gather(2, target[0].unsqueeze(-1)).squeeze(-1)
    target_loss = target_loss * mask
    # Epsilon keeps log() finite; masked positions contribute a constant
    # log(1e-15) to the sum but no gradient.
    target_loss += 1e-15
    target_loss = target_loss.log()
    loss = -((target_loss.sum()) / num_words)
    out_emo = outputs.logits  # batch x max_len x dim
    batch_size, max_len, class_num = out_emo.size()
    # out_emo=out_emo.view(batch_size*max_len, class_num)
    # emo_target=emo_target.view(-1)
    target_emo_loss = out_emo.gather(
        2, emo_target[0].unsqueeze(-1)).squeeze(-1)
    # NOTE(review): in-place subtraction also mutates target[1] (the caller's
    # length tensor) — confirm nothing downstream reuses it.
    target_len -= 1
    mask_ = sequence_mask(target_len)
    mask_ = mask_.float()
    new_mask = mask.data.new(batch_size, max_len).zero_()
    # print(mask.size())
    # print(new_mask.size())
    # Shifted mask: emotion labels align one step earlier than tokens.
    new_mask[:, :max_len - 1] = mask_
    target_emo_loss = target_emo_loss * new_mask
    target_emo_loss += 1e-15
    target_emo_loss = target_emo_loss.log()
    emo_loss = -((target_emo_loss.sum()) / num_words)
    metrics.add(loss=loss)
    metrics.add(emo_loss=emo_loss)
    # Here we only compute accuracy over the copy distribution.
    acc = accuracy(out_copy, target[0], mask=mask)
    metrics.add(acc=acc)
    return metrics
def encode_infe(self, inputs, elmo_embed, hidden=None):
    """Encode for inference using precomputed ELMo embeddings.

    Builds the decoder's initial state with a zero input-feed vector.
    """
    outputs = Pack()
    # input_raw = inputs.raw_src
    enc_outputs, enc_hidden = self.encoder.infer(inputs.num_src,
                                                 elmo_embed, hidden)
    if self.with_bridge:
        enc_hidden = self.bridge1(enc_hidden)
    _, batch_size, dim = enc_hidden.size()
    # Zero-initialized input-feed vector shaped (batch, 1, dim), allocated
    # on the same device/dtype as the encoder hidden state.
    init_feed = enc_hidden.data.new(batch_size, dim).zero_().unsqueeze(1)
    dec_init_state = self.decoder.initialize_state(
        hidden=enc_hidden,
        input_feed=init_feed,
        attn_memory=enc_outputs if self.attn_mode else None,
        mask=inputs.mask[0])
    return outputs, dec_init_state
def encode(self, inputs, hidden=None):
    """Encode the source with the RNN encoder and the knowledge ("cue")
    sentences with the memory encoder, then build the decoder state."""
    outputs = Pack()
    # Trim BOS/EOS from the source and shrink lengths to match.
    trimmed_src = inputs.src[0][:, 1:-1]
    lengths = inputs.src[1] - 2
    enc_outputs, enc_hidden = self.rnn_encoder((trimmed_src, lengths), hidden)

    # Knowledge: strip BOS/EOS per sentence as well.
    batch_size, sent_num, sent = inputs.cue[0].size()
    tmp_len = inputs.cue[1]
    # NOTE: modifies inputs.cue[1] in place (same behavior as before).
    tmp_len[tmp_len > 0] -= 2
    cue_inputs = inputs.cue[0][:, :, 1:-1], tmp_len
    # Memory encoder conditions on the encoder's top-layer hidden state.
    u = self.mem_encoder(cue_inputs, enc_hidden[-1])

    dec_init_state = self.decoder.initialize_state(
        hidden=u.unsqueeze(0),
        attn_memory=enc_outputs if self.attn_mode else None,
        memory_lengths=lengths if self.attn_mode else None)
    return outputs, dec_init_state
def collect_metrics(self, outputs, target):
    """Compute NLL loss and token accuracy for one batch."""
    metrics = Pack(num_samples=target.size(0))
    logits = outputs.logits
    nll = self.nll_loss(logits, target)
    # Count non-padding tokens for per-word normalization downstream.
    num_words = target.ne(self.padding_idx).sum().item()
    token_acc = accuracy(logits, target, padding_idx=self.padding_idx)
    metrics.add(nll=(nll, num_words), acc=token_acc)
    metrics.add(loss=nll)
    return metrics
def interact(self, src, cue=None):
    """Answer one interactive query; returns the top prediction or None.

    The source (and optional cue) string is numericalized, batched as a
    single sample, and decoded with a single candidate.
    """
    if src == "":
        return None
    inputs = Pack()
    inputs.add(src=list2tensor(self.src_field.numericalize([src])))
    if cue is not None:
        inputs.add(cue=list2tensor(self.cue_field.numericalize([cue])))
    if self.use_gpu:
        inputs = inputs.cuda()
    _, preds, _, _ = self.forward(inputs=inputs, num_candidates=1)
    # First sample, first (and only) candidate.
    return self.tgt_field.denumericalize(preds[0][0])
def encode(self, inputs, hidden=None):
    """Encode dialog history and knowledge facts, then build the decoder
    state with separate attention memories for each.

    History and facts are both trimmed of BOS/EOS; fact sentences are
    flattened to (batch*sent_num, sent-2) for encoding, then reshaped so
    all fact tokens of one dialog form a single attention memory.
    """
    outputs = Pack()
    # Chained assignment: (tokens, lengths) with BOS/EOS stripped.
    hist_inputs = _, lengths = inputs.src[0][:, 1:-1], inputs.src[1] - 2
    # (batch_size, seq_length, hidden_size*num_directions)
    # (num_layers, batch_size, num_directions * hidden_size)
    hist_outputs, hist_hidden = self.hist_encoder(hist_inputs, hidden)
    if self.with_bridge:
        hist_hidden = self.bridge(hist_hidden)
    # knowledge
    batch_size, sent_num, sent = inputs.cue[0].size()
    tmp_len = inputs.cue[1]
    # NOTE(review): modifies inputs.cue[1] in place.
    tmp_len[tmp_len > 0] -= 2
    fact_inputs = inputs.cue[0].view(-1, sent)[:, 1:-1], tmp_len.view(-1)
    fact_enc_outputs, fact_enc_hidden = self.fact_encoder(
        fact_inputs, hidden)
    # print(fact_enc_outputs.size())
    # All fact tokens of one dialog concatenated into one memory.
    fact_outputs = fact_enc_outputs.view(batch_size,
                                         sent_num * (sent - 2), -1)
    # # (batch_size, sent_num, hidden_size)
    # fact_hidden = fact_enc_hidden[-1].view(batch_size, sent_num, -1)
    # # (batch_size, hidden_size)
    # fact_hidden = torch.sum(fact_hidden, 1).squeeze(1)
    # print(hist_hidden[-1].size(), hist_outputs.size(), fact_outputs.size())
    # print(lengths)
    # print(tmp_len)
    dec_init_state = self.decoder.initialize_state(
        hidden=hist_hidden,
        # fact_hidden=fact_hidden,
        fact=inputs.cue[0][:, :, 1:-1].contiguous().view(batch_size, -1),
        hist=inputs.src[0][:, 1:-1],
        attn_fact=fact_outputs if self.attn_mode else None,
        attn_hist=hist_outputs if self.attn_mode else None,
        fact_lengths=tmp_len if self.attn_mode else None,
        hist_lengths=lengths if self.attn_mode else None)
    return outputs, dec_init_state
def encode(self, inputs, hidden=None):
    """Encode source tokens together with POS features and build the
    decoder's initial state with a zero input-feed vector."""
    outputs = Pack()
    enc_inputs, lengths = inputs.num_src
    pos_inputs = inputs.num_pos[0]
    enc_outputs, enc_hidden = self.encoder(enc_inputs, pos_inputs, hidden)
    if self.with_bridge:
        enc_hidden = self.bridge1(enc_hidden)
    _, batch_size, dim = enc_hidden.size()
    # Zero input-feed vector shaped (batch, 1, dim) on the same device.
    input_feed = enc_hidden.data.new(batch_size, dim).zero_().unsqueeze(1)
    dec_init_state = self.decoder.initialize_state(
        hidden=enc_hidden,
        input_feed=input_feed,
        attn_memory=enc_outputs if self.attn_mode else None,
        mask=inputs.mask[0])
    return outputs, dec_init_state
def encode(self, inputs, hidden=None):
    """Hierarchically encode multi-utterance input and build the decoder
    state; attention lengths count non-empty utterances per dialog."""
    outputs = Pack()
    lengths = inputs.src[1]
    # In-place fixup: shrink non-zero lengths after trimming BOS/EOS.
    lengths[lengths > 0] -= 2
    enc_inputs = inputs.src[0][:, :, 1:-1], lengths
    # Number of non-empty utterances per dialog.
    hiera_lengths = lengths.gt(0).long().sum(dim=1)
    enc_outputs, enc_hidden, _ = self.encoder(enc_inputs, hidden)
    if self.with_bridge:
        enc_hidden = self.bridge(enc_hidden)
    dec_init_state = self.decoder.initialize_state(
        hidden=enc_hidden,
        attn_memory=enc_outputs if self.attn_mode else None,
        memory_lengths=hiera_lengths if self.attn_mode else None)
    return outputs, dec_init_state
def collate(data_list):
    """Collate paired samples into two batches.

    Each element of data_list is a (sample1, sample2) pair; the pairs are
    unzipped and each side is batched independently, so both batches share
    sample alignment by position.
    """
    def _collate_one(samples):
        # Batch one side of the pair: tensorize every field, move to GPU.
        batch = Pack()
        for key in samples[0].keys():
            batch[key] = list2tensor([x[key] for x in samples])
        if device >= 0:
            batch = batch.cuda(device=device)
        return batch

    # Factoring the duplicated loop into _collate_one also removes the
    # redundant second list(data_list2) conversion the original performed.
    data_list1, data_list2 = zip(*data_list)
    batch1 = _collate_one(list(data_list1))
    batch2 = _collate_one(list(data_list2))
    return batch1, batch2