import json
import os
from typing import Dict, List

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import GPT2Tokenizer

# Project-local imports; the module paths here are assumptions, adjust them
# to the actual repository layout. `top_k_top_p_filtering` is sketched after
# sample_sequence below.
from model import JointSentiGPT2Model  # path assumed
from train_utils import compute_metrics, validation  # paths assumed


def sample_sequence(cfg,
                    model: JointSentiGPT2Model,
                    tokenizer: GPT2Tokenizer,
                    context_token: torch.Tensor,
                    token_type: torch.Tensor,
                    context_emotion: torch.Tensor,
                    cls_mask: torch.Tensor,
                    emotion_pad=0,
                    speaker1_state=2,
                    decoding_strategy='sampling'):
    """Decode one response token-by-token, reusing cached hidden states."""
    # Mask block appended to cls_mask for every newly generated token.
    cls_mask_extra = torch.LongTensor([[[1], [0], [0], [0]]]).to(cfg.device)
    generated = context_token
    past, pred_response_emotion = None, None
    result = []
    for step in range(cfg.max_decode_length):
        inputs = {
            'input_ids': generated,
            'token_type_ids': token_type,
            'emotion_ids': context_emotion,
            'pred_response_emotion_vector': pred_response_emotion,
            'cls_mask': cls_mask,
            'past': past,
            'decoding': True
        }
        # 'past' carries the cached hidden states (as with GPT-2/Transfo-XL/
        # XLNet), so after the first step only the new token is fed in.
        outputs = model.decoding(**inputs)
        pred_response_emotion, past = outputs[1:]
        next_token_logits = outputs[0][0, -1, :] / cfg.sampling_temperature
        if decoding_strategy == 'sampling':
            filtered_logits = top_k_top_p_filtering(next_token_logits,
                                                    top_k=cfg.top_k,
                                                    top_p=cfg.top_p)
            prob = F.softmax(filtered_logits, dim=-1)
            next_token = torch.multinomial(prob, num_samples=1)
        else:
            next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)
        if next_token.item() == tokenizer.eos_token_id \
                and step >= cfg.min_decode_length:
            break
        result.append(next_token.item())
        # Feed only the newly generated token on the next step.
        generated = next_token.unsqueeze(0)
        token_type = torch.LongTensor([[speaker1_state]]).to(cfg.device)
        cls_mask = torch.cat((cls_mask, cls_mask_extra), dim=-1)
    result = [
        token_id for token_id in result if token_id not in cfg.special_id_list
    ]
    text = tokenizer.decode(result,
                            skip_special_tokens=True,
                            clean_up_tokenization_spaces=False)
    text = text.replace("\n", "").replace("\r", "")
    return text
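
# `top_k_top_p_filtering` is called above but not defined in this file. Below
# is a minimal sketch, assuming the widely circulated implementation from the
# HuggingFace `run_generation.py` example for a 1D logits tensor; the
# repository's own helper may differ in details.
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Mask a 1D logits tensor with top-k and/or nucleus (top-p) filtering."""
    if top_k > 0:
        # Drop every token whose logit is below the k-th largest logit.
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Drop tokens once cumulative probability exceeds top_p, shifting the
        # mask right so the first token crossing the threshold is kept.
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits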
def evaluation(cfg, model: JointSentiGPT2Model, tokenizer: GPT2Tokenizer,
               test_dataset: List[Dict], test_dataloader: DataLoader, d_type):
    src, hypothesis, dialog_situation_label = [], [], []
    model.eval()
    with torch.no_grad():
        for one_batch in tqdm(test_dataset, desc=f'decoding {d_type}...'):
            context_id = one_batch['input_ids']
            if d_type == 'test':
                dialog_situation_label.append(one_batch['label'])
            src.append(tokenizer.decode(context_id))
            # [seq_length] -> [1, seq_length]
            token_type_id = torch.LongTensor(
                one_batch['token_type_ids']).unsqueeze(0).to(cfg.device)
            if cfg.emotion_cls == 'coarse':
                emotion_id = torch.LongTensor(
                    one_batch['coarse_grained_emotion_ids']).unsqueeze(0).to(
                        cfg.device)
            else:
                emotion_id = torch.LongTensor(
                    one_batch['fine_grained_emotion_ids']).unsqueeze(0).to(
                        cfg.device)
            context_id = torch.LongTensor(context_id).unsqueeze(0).to(
                cfg.device)
            # [2, seq_length] -> [1, 2, seq_length]
            cls_mask = torch.LongTensor(one_batch['cls_mask']).unsqueeze(0).to(
                cfg.device)
            hyp = sample_sequence(cfg, model, tokenizer, context_id,
                                  token_type_id, emotion_id, cls_mask,
                                  decoding_strategy=cfg.decoding_method)
            hypothesis.append(hyp)

    suffix = ('greedy' if cfg.decoding_method == 'greedy' else
              f'sampling_topk{cfg.top_k}_topp{cfg.top_p}_tau{cfg.sampling_temperature}')
    if d_type in ['train', 'valid']:
        hyp_file = os.path.join(
            cfg.save_dir,
            f"epo{cfg.best_epoch}_{d_type}_hyp{cfg.min_decode_length}_{cfg.max_decode_length}_{suffix}.txt"
        )
        with open(hyp_file, 'w', encoding='utf-8') as f_hyp:
            for hypo in hypothesis:
                f_hyp.write(hypo.strip() + '\n')
    else:
        hyp_file = os.path.join(
            cfg.save_dir,
            f"epo{cfg.best_epoch}_test_hyp{cfg.min_decode_length}_{cfg.max_decode_length}_{suffix}.txt"
        )
        check_file = os.path.join(
            cfg.save_dir,
            f"epo{cfg.best_epoch}_test_check{cfg.min_decode_length}_{cfg.max_decode_length}_{suffix}.txt"
        )
        if cfg.lower:
            result_file = os.path.join(
                cfg.save_dir,
                f"epo{cfg.best_epoch}_result{cfg.min_decode_length}_{cfg.max_decode_length}_lower_{suffix}.json"
            )
        else:
            result_file = os.path.join(
                cfg.save_dir,
                f"epo{cfg.best_epoch}_result{cfg.min_decode_length}_{cfg.max_decode_length}_{suffix}.json"
            )
        ref_txt = cfg.ref_file
        with open(check_file, 'w', encoding='utf-8') as f_check, \
                open(hyp_file, 'w', encoding='utf-8') as f_hyp, \
                open(ref_txt, 'r', encoding='utf-8') as f_r:
            for dialog_history, situation_label, hypo, ref in zip(
                    src, dialog_situation_label, hypothesis, f_r):
                f_check.write(f"situation: {situation_label}\n")
                f_check.write(f"history: {dialog_history}\n")
                f_check.write(f"hyp: {hypo.strip()}\n")
                f_check.write(f"ref: {ref.strip()}\n")
                f_check.write("\n")
                f_hyp.write(hypo.strip() + '\n')
        # compute regular metrics against the reference file
        result = compute_metrics(hyp_file=hyp_file,
                                 ref_file=ref_txt,
                                 glove_path=cfg.glove_path,
                                 lower=cfg.lower,
                                 space_token=False)
        # compute perplexity from the mean NLL on the test set
        NLL_Loss, _, _ = validation(cfg, model, test_dataloader, decoding=True)
        ppl = np.exp(NLL_Loss)
        print(f"perplexity: {ppl:.6f}")
        result.update({"perplexity": ppl})
        with open(result_file, 'w', encoding='utf-8') as f:
            json.dump(result, f)
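
# For reference, the perplexity reported above is exp(mean token-level NLL)
# as returned by `validation`. A self-contained toy illustration follows;
# the NLL values are made up purely for demonstration.
if __name__ == '__main__':
    toy_nll = np.mean([2.31, 2.87, 3.02, 2.55])  # pretend mean NLL per token
    print(f"perplexity: {np.exp(toy_nll):.6f}")  # exp(2.6875) ~= 14.69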