Example #1
import torch
import torch.nn.functional as F
from transformers import GPT2Tokenizer

# JointSentiGPT2Model and top_k_top_p_filtering are project-local names;
# the latter matches the standard Hugging Face sampling helper (see the
# reference sketch after this function).


def sample_sequence(cfg,
                    model: JointSentiGPT2Model,
                    tokenizer: GPT2Tokenizer,
                    context_token: torch.Tensor,
                    token_type: torch.Tensor,
                    context_emotion: torch.Tensor,
                    cls_mask: torch.Tensor,
                    emotion_pad=0,
                    speaker1_state=2,
                    decoding_strategy='sampling'):
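    """Autoregressively decode one response, reusing cached hidden states.

    Shapes (inferred from the call sites in `evaluation` below):
    context_token, token_type, and context_emotion are [1, seq_len];
    cls_mask is [1, 4, seq_len] and grows by one column per generated token.
    """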
    # Column appended to cls_mask for each newly generated token.
    cls_mask_extra = torch.LongTensor([[[1], [0], [0], [0]]]).to(cfg.device)

    generated = context_token

    past, pred_response_emotion = None, None
    result = []
    for step in range(cfg.max_decode_length):
        inputs = {
            'input_ids': generated,
            'token_type_ids': token_type,
            'emotion_ids': context_emotion,
            'pred_response_emotion_vector': pred_response_emotion,
            'cls_mask': cls_mask,
            'past': past,
            'decoding': True
        }
        # `past` carries the cached hidden states, so only the newly
        # generated token needs to be fed at each subsequent step.
        outputs = model.decoding(**inputs)
        pred_response_emotion, past = outputs[1:]
        next_token_logits = outputs[0][0, -1, :] / cfg.sampling_temperature
        if decoding_strategy == 'sampling':
            # Top-k/top-p (nucleus) filtering followed by multinomial sampling.
            filtered_logits = top_k_top_p_filtering(next_token_logits,
                                                    top_k=cfg.top_k,
                                                    top_p=cfg.top_p)
            prob = F.softmax(filtered_logits, dim=-1)
            next_token = torch.multinomial(prob, num_samples=1)
        else:
            # Greedy decoding: pick the single most likely token.
            next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)

        # Stop at EOS once the minimum decode length has been reached.
        if (next_token.item() == tokenizer.eos_token_id
                and step >= cfg.min_decode_length):
            break

        result.append(next_token.item())
        # With cached `past`, only the newly generated token (plus its token
        # type and an extra cls_mask column) is fed at the next step.
        generated = next_token.unsqueeze(0)
        token_type = torch.LongTensor([[speaker1_state]]).to(cfg.device)
        cls_mask = torch.cat((cls_mask, cls_mask_extra), dim=-1)

    # Drop any remaining special tokens before detokenizing.
    result = [
        token_id for token_id in result if token_id not in cfg.special_id_list
    ]
    text = tokenizer.decode(result,
                            skip_special_tokens=True,
                            clean_up_tokenization_spaces=False)
    text = text.replace("\n", "").replace("\r", "")
    return text
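top_k_top_p_filtering is not defined in these examples. Its call signature matches the widely circulated Hugging Face sampling helper, reproduced below as a reference sketch for the 1-D logits used in sample_sequence (assumption: the project uses this standard implementation; torch and F are imported as above).

def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Filter a 1-D logits tensor with top-k and/or nucleus (top-p) filtering."""
    assert logits.dim() == 1  # batch size 1, as in sample_sequence above
    top_k = min(top_k, logits.size(-1))  # safety check
    if top_k > 0:
        # Remove every token whose logit is below the k-th largest logit.
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens once cumulative probability exceeds top_p; shift the
        # mask right so the first token above the threshold is kept.
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits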
Example #2
import json
import os
from typing import Dict, List

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

# compute_metrics and validation are project-local helpers;
# sample_sequence is defined in Example #1 above.


def evaluation(cfg, model: JointSentiGPT2Model, tokenizer: GPT2Tokenizer,
               test_dataset: List[Dict], test_dataloader: DataLoader, d_type):
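    """Decode a response for every example in test_dataset and write the
    hypotheses to disk; for d_type == 'test', also compute generation
    metrics and perplexity against cfg.ref_file."""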
    src, hypothesis, dialog_situation_label = [], [], []
    model.eval()
    with torch.no_grad():
        for one_batch in tqdm(test_dataset, desc=f'decoding {d_type}...'):
            context_id = one_batch['input_ids']
            if d_type == 'test':
                dialog_situation_label.append(one_batch['label'])
            src.append(tokenizer.decode(context_id))

            # [seq_length] -> [1,seq_length]
            token_type_id = torch.LongTensor(
                one_batch['token_type_ids']).unsqueeze(0).to(cfg.device)
            if cfg.emotion_cls == 'coarse':
                emotion_id = torch.LongTensor(
                    one_batch['coarse_grained_emotion_ids']).unsqueeze(0).to(
                        cfg.device)
            else:
                emotion_id = torch.LongTensor(
                    one_batch['fine_grained_emotion_ids']).unsqueeze(0).to(
                        cfg.device)
            context_id = torch.LongTensor(context_id).unsqueeze(0).to(
                cfg.device)

            # [4, seq_length] -> [1, 4, seq_length] (matches cls_mask_extra
            # in sample_sequence above)
            cls_mask = torch.LongTensor(one_batch['cls_mask']).unsqueeze(0).to(
                cfg.device)

            hyp = sample_sequence(cfg,
                                  model,
                                  tokenizer,
                                  context_id,
                                  token_type_id,
                                  emotion_id,
                                  cls_mask,
                                  decoding_strategy=cfg.decoding_method)
            hypothesis.append(hyp)

    if cfg.decoding_method == 'greedy':
        suffix = 'greedy'
    else:
        suffix = (f'sampling_topk{cfg.top_k}_topp{cfg.top_p}'
                  f'_tau{cfg.sampling_temperature}')
    if d_type in ['train', 'valid']:
        hyp_file = os.path.join(
            cfg.save_dir,
            f"epo{cfg.best_epoch}_{d_type}_hyp{cfg.min_decode_length}_{cfg.max_decode_length}_{suffix}.txt"
        )
        with open(hyp_file, 'w', encoding='utf-8') as f_hyp:
            for hypo in hypothesis:
                f_hyp.write(hypo.strip() + '\n')
    else:
        hyp_file = os.path.join(
            cfg.save_dir,
            f"epo{cfg.best_epoch}_test_hyp{cfg.min_decode_length}_{cfg.max_decode_length}_{suffix}.txt"
        )
        check_file = os.path.join(
            cfg.save_dir,
            f"epo{cfg.best_epoch}_test_check{cfg.min_decode_length}_{cfg.max_decode_length}_{suffix}.txt"
        )

        if cfg.lower:
            result_file = os.path.join(
                cfg.save_dir,
                f"epo{cfg.best_epoch}_result{cfg.min_decode_length}_{cfg.max_decode_length}_lower_{suffix}.json"
            )
        else:
            result_file = os.path.join(
                cfg.save_dir,
                f"epo{cfg.best_epoch}_result{cfg.min_decode_length}_{cfg.max_decode_length}_{suffix}.json"
            )

        ref_txt = cfg.ref_file

        with open(check_file, 'w', encoding='utf-8') as f_check, open(hyp_file, 'w', encoding='utf-8') as f_hyp, \
                open(ref_txt, 'r', encoding='utf-8') as f_r:
            for dialog_history, situation_label, hypo, ref in zip(
                    src, dialog_situation_label, hypothesis, f_r):
                f_check.writelines(f"situation: {situation_label}\n")
                f_check.writelines(f"history: {dialog_history}\n")
                f_check.writelines(f"hyp: {hypo.strip()}\n")
                f_check.writelines(f"ref: {ref.strip()}\n")
                f_check.writelines("\n")

                f_hyp.writelines(hypo.strip() + '\n')

        # compute standard generation metrics
        result = compute_metrics(hyp_file=hyp_file,
                                 ref_file=ref_txt,
                                 glove_path=cfg.glove_path,
                                 lower=cfg.lower,
                                 space_token=False)
        # compute perplexity: exp of the mean token-level NLL from validation
        NLL_Loss, _, _ = validation(cfg, model, test_dataloader, decoding=True)
        ppl = np.exp(NLL_Loss)
        print(f"perplexity: {ppl:.6f}")
        result.update({"perplexity": ppl})

        with open(result_file, 'w', encoding='utf-8') as f:
            json.dump(result, f)
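For reference, a hypothetical driver sketch showing the cfg fields these two functions rely on. Every field below is referenced somewhere above; the values are placeholders for illustration, not the project's actual defaults.

from types import SimpleNamespace

cfg = SimpleNamespace(
    device='cuda',                # target of all .to(cfg.device) calls
    max_decode_length=50,         # decoding loop bound in sample_sequence
    min_decode_length=5,          # earliest step at which EOS may stop decoding
    sampling_temperature=0.7,     # logit temperature
    top_k=0, top_p=0.9,           # nucleus-sampling parameters
    decoding_method='sampling',   # or 'greedy'
    emotion_cls='coarse',         # anything else selects fine-grained emotion ids
    special_id_list=[],           # token ids stripped from decoded output
    save_dir='./output', best_epoch=0, lower=True,
    ref_file='refs.txt', glove_path='./glove.6B.300d.txt',
)
# evaluation(cfg, model, tokenizer, test_dataset, test_dataloader, d_type='test')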