def decode(args: Dict[str, str]):
    """Decode a test set with a trained NMT model and save the result.

    Loads test source sentences (and gold-standard target sentences when the
    target file exists), runs beam search with the saved model, computes
    corpus-level BLEU when targets are available, and writes the score to
    ``<eval_dir>/decode.txt``.

    NOTE(review): despite the ``Dict[str, str]`` annotation, ``args`` is read
    with attribute access (``args.test_dir`` etc.), so it is presumably an
    ``argparse.Namespace``-like object — confirm at the call site.
    """
    test_src_dir = os.path.join(args.test_dir, args.input_col.lower())
    test_tgt_dir = os.path.join(args.test_dir, args.output_col.lower())

    print(f"load test source sentences from [{test_src_dir}]", file=sys.stderr)
    test_data_src = read_corpus(test_src_dir, source='src')

    # BUG FIX: the original tested `if test_tgt_dir:`, but os.path.join always
    # returns a non-empty string, so the guard was always true — and the BLEU
    # computation below used test_data_tgt unconditionally anyway, which would
    # NameError if the target were ever missing. Guard on actual existence.
    test_data_tgt = None
    if os.path.exists(test_tgt_dir):
        print(f"load test target sentences from [{test_tgt_dir}]", file=sys.stderr)
        test_data_tgt = read_corpus(test_tgt_dir, source='tgt')

    model_path = os.path.join(args.model_dir, 'model.bin')
    print(f"load model from {model_path}", file=sys.stderr)
    model = NMT.load(model_path)

    if args.cuda:
        model = model.to(torch.device("cuda:0"))

    # Consistency with init(): put the model in inference mode so dropout
    # (if any) is disabled during beam search.
    model.eval()

    hypotheses = beam_search(model, test_data_src,
                             beam_size=int(args.beam_size),
                             max_decoding_time_step=int(args.max_decoding_time_step))
    top_hypotheses = [hyps[0] for hyps in hypotheses]

    # Only score when gold targets were actually loaded.
    bleu_score = None
    if test_data_tgt is not None:
        bleu_score = compute_corpus_level_bleu_score(test_data_tgt, top_hypotheses)
        print(f'Corpus BLEU: {bleu_score}', file=sys.stderr)

    output_path = os.path.join(args.eval_dir, 'decode.txt')
    with open(output_path, 'w') as f:
        f.write(str(bleu_score))
def init():
    """Scoring-service entry point: load the serialized NMT model once.

    Resolves the registered 'arxiv-nmt-pipeline' model directory, loads the
    weights from ``model.bin`` into the module-level ``model`` global, and
    switches it to inference mode.
    """
    global model
    weights_path = os.path.join(
        Model.get_model_path('arxiv-nmt-pipeline'), 'model.bin')
    model = NMT.load(weights_path)
    model.eval()