def __init__(self, data_path):
    """Load the sorted parallel corpus and cache tokenizers and special ids."""
    sentence_pairs = self.get_dataset(data_path, sort=True)
    self.out_en_sent, self.out_cn_sent = sentence_pairs
    # SentencePiece models for the two languages.
    self.sp_eng = english_tokenizer_load()
    self.sp_chn = chinese_tokenizer_load()
    # Special-token ids are taken from the English tokenizer model.
    self.PAD = self.sp_eng.pad_id()  # 0
    self.BOS = self.sp_eng.bos_id()  # 2
    self.EOS = self.sp_eng.eos_id()  # 3
def evaluate(data, model, mode='dev', use_beam=True):
    """Decode ``data`` with the trained model and return the corpus BLEU score.

    Args:
        data: iterable of batches exposing ``src`` (token-id tensor) and
            ``trg_text`` (reference Chinese sentences).
        model: trained seq2seq model used for decoding.
        mode: when ``'test'``, also dump ``idx:<i><ref>|||<hyp>`` lines to
            ``config.output_path``.
        use_beam: beam search when True, greedy decoding otherwise.

    Returns:
        float: sacreBLEU corpus score computed with the ``'zh'`` tokenizer.
    """
    sp_chn = chinese_tokenizer_load()
    trg = []
    res = []
    with torch.no_grad():
        for batch in tqdm(data):
            cn_sent = batch.trg_text  # reference Chinese sentences
            src = batch.src
            # Mask out padding positions (pad id is 0).
            src_mask = (src != 0).unsqueeze(-2)
            if use_beam:
                decode_result, _ = beam_search(model, src, src_mask, config.max_len,
                                               config.padding_idx, config.bos_idx,
                                               config.eos_idx, config.beam_size,
                                               config.device)
                # beam_search returns an n-best list per sentence; keep only the
                # top hypothesis (consistent with translate()).
                decode_result = [h[0] for h in decode_result]
            else:
                # batch_greedy_decode already returns one id list per sentence.
                decode_result = batch_greedy_decode(model, src, src_mask,
                                                    max_len=config.max_len)
            translation = [sp_chn.decode_ids(ids) for ids in decode_result]
            trg.extend(cn_sent)
            res.extend(translation)
    if mode == 'test':
        # Explicit UTF-8 so Chinese text round-trips regardless of locale.
        with open(config.output_path, "w", encoding="utf-8") as fp:
            for i, (ref, hyp) in enumerate(zip(trg, res)):
                fp.write("idx:" + str(i) + ref + '|||' + hyp + '\n')
    # sacrebleu expects a list of reference streams.
    trg = [trg]
    bleu = sacrebleu.corpus_bleu(res, trg, tokenize='zh')
    return float(bleu.score)
def translate(src, model, use_beam=True):
    """Translate a single source batch with the trained model and print the result.

    Args:
        src: source token-id tensor (batch dimension first).
        model: model instance; its weights are loaded from ``config.model_path``.
        use_beam: beam search when True, greedy decoding otherwise.

    Returns:
        str: the decoded Chinese translation of the first sentence (also printed).
    """
    sp_chn = chinese_tokenizer_load()
    # map_location lets a GPU-trained checkpoint load on the configured device
    # (plain torch.load fails on a CPU-only host).
    model.load_state_dict(torch.load(config.model_path, map_location=config.device))
    model.eval()
    with torch.no_grad():
        # Mask out padding positions (pad id is 0).
        src_mask = (src != 0).unsqueeze(-2)
        if use_beam:
            decode_result, _ = beam_search(model, src, src_mask, config.max_len,
                                           config.padding_idx, config.bos_idx,
                                           config.eos_idx, config.beam_size,
                                           config.device)
            # Keep only the best hypothesis from each n-best list.
            decode_result = [h[0] for h in decode_result]
        else:
            decode_result = batch_greedy_decode(model, src, src_mask,
                                                max_len=config.max_len)
        translation = [sp_chn.decode_ids(ids) for ids in decode_result]
    print(translation[0])
    return translation[0]