Example 1
 def __init__(self, data_path):
     """Load tokenizers and the parallel corpus, caching special-token ids."""
     self.sp_eng = english_tokenizer_load()
     self.sp_chn = chinese_tokenizer_load()
     # Special-token ids come from the English tokenizer model.
     self.PAD = self.sp_eng.pad_id()  # 0
     self.BOS = self.sp_eng.bos_id()  # 2
     self.EOS = self.sp_eng.eos_id()  # 3
     self.out_en_sent, self.out_cn_sent = self.get_dataset(data_path, sort=True)
Example 2
def evaluate(data, model, mode='dev', use_beam=True):
    """Run the trained model over `data` and return the corpus BLEU score.

    Args:
        data: iterable of batches exposing `.src` (source token-id tensor)
            and `.trg_text` (reference Chinese sentences).
        model: trained translation model.
        mode: when 'test', also write "idx:<i><ref>|||<hyp>" lines to
            config.output_path.
        use_beam: decode with beam search when True, greedy decode otherwise.

    Returns:
        float: sacreBLEU score (tokenize='zh') of hypotheses vs. references.
    """
    sp_chn = chinese_tokenizer_load()
    trg = []
    res = []
    with torch.no_grad():
        for batch in tqdm(data):
            cn_sent = batch.trg_text  # reference Chinese sentences
            src = batch.src
            # Mask out padding positions (pad id is 0).
            src_mask = (src != 0).unsqueeze(-2)
            if use_beam:
                decode_result, _ = beam_search(model, src, src_mask,
                                               config.max_len,
                                               config.padding_idx,
                                               config.bos_idx, config.eos_idx,
                                               config.beam_size, config.device)
                # beam_search returns a list of candidate hypotheses per
                # sentence; keep only the best one. batch_greedy_decode
                # already yields one sequence per sentence, so no reduction
                # there — reducing both paths (as before) would truncate
                # greedy output to its first token. This mirrors translate().
                decode_result = [h[0] for h in decode_result]
            else:
                decode_result = batch_greedy_decode(model, src, src_mask,
                                                    max_len=config.max_len)
            translation = [sp_chn.decode_ids(_s) for _s in decode_result]
            trg.extend(cn_sent)
            res.extend(translation)
    if mode == 'test':
        with open(config.output_path, "w") as fp:
            for i, (ref, hyp) in enumerate(zip(trg, res)):
                fp.write("idx:" + str(i) + ref + '|||' + hyp + '\n')
    trg = [trg]  # sacrebleu expects a list of reference streams
    bleu = sacrebleu.corpus_bleu(res, trg, tokenize='zh')
    return float(bleu.score)
Example 3
def translate(src, model, use_beam=True):
    """Translate a single batched source tensor with the trained model and print the result."""
    sp_chn = chinese_tokenizer_load()
    with torch.no_grad():
        # Restore trained weights and switch to inference mode.
        model.load_state_dict(torch.load(config.model_path))
        model.eval()
        # Mask out padding positions (pad id is 0).
        src_mask = (src != 0).unsqueeze(-2)
        if not use_beam:
            hyps = batch_greedy_decode(model, src, src_mask,
                                       max_len=config.max_len)
        else:
            candidates, _ = beam_search(model, src, src_mask,
                                        config.max_len, config.padding_idx,
                                        config.bos_idx, config.eos_idx,
                                        config.beam_size, config.device)
            # Keep only the best hypothesis for each sentence.
            hyps = [beams[0] for beams in candidates]
        sentences = [sp_chn.decode_ids(ids) for ids in hyps]
        print(sentences[0])