def decode(args: Dict[str, str]):
    """ Load the test data and the trained model, then run decoding.

    Reads the source sentences, optionally the gold-standard target sentences
    (only when a target file is given), loads the model (moving it to the first
    CUDA device when requested) and hands everything to decode_with_params.
    Progress messages go to stderr.

    NOTE(review): writing the decoding results and computing BLEU presumably
    happen inside decode_with_params; that helper is not visible here -- confirm.

    @param args (Dict): args from cmd line
    """
    src_path = args['TEST_SOURCE_FILE']
    print("load test source sentences from [{}]".format(src_path), file=sys.stderr)
    source_sentences = read_corpus(src_path, source='src')

    # References are optional; None tells the helper there is no gold standard.
    reference_sentences = None
    tgt_path = args['TEST_TARGET_FILE']
    if tgt_path:
        print("load test target sentences from [{}]".format(tgt_path), file=sys.stderr)
        reference_sentences = read_corpus(tgt_path, source='tgt')

    model_path = args['MODEL_PATH']
    print("load model from {}".format(model_path), file=sys.stderr)
    nmt_model = NMT.load(model_path)
    if args['--cuda']:
        nmt_model = nmt_model.to(torch.device("cuda:0"))

    # Command-line values arrive as strings, hence the int() conversions.
    decode_with_params(
        nmt_model,
        source_sentences,
        reference_sentences,
        int(args['--beam-size']),
        int(args['--max-decoding-time-step']),
        args['OUTPUT_FILE'],
    )
def decode(args: Dict[str, str]):
    """ Performs decoding on a test set, and save the best-scoring decoding results.
    If the target gold-standard sentences are given, the function also computes
    corpus-level BLEU score and prints it (multiplied by 100) to stderr.

    Keys read from args: TEST_SOURCE_FILE, TEST_TARGET_FILE, MODEL_PATH, --cuda,
    --beam-size, --max-decoding-time-step and OUTPUT_FILE.

    @param args (Dict): args from cmd line
    """
    src_path = args['TEST_SOURCE_FILE']
    print("load test source sentences from [{}]".format(src_path), file=sys.stderr)
    test_data_src = read_corpus(src_path, source='src')

    # Bind the references unconditionally so the name is never left undefined
    # when no target file is supplied.
    test_data_tgt = None
    tgt_path = args['TEST_TARGET_FILE']
    if tgt_path:
        print("load test target sentences from [{}]".format(tgt_path), file=sys.stderr)
        test_data_tgt = read_corpus(tgt_path, source='tgt')

    print("load model from {}".format(args['MODEL_PATH']), file=sys.stderr)
    model = NMT.load(args['MODEL_PATH'])
    if args['--cuda']:
        model = model.to(torch.device("cuda:0"))

    # Command-line values arrive as strings, hence the int() conversions.
    hypotheses = beam_search(model, test_data_src,
                             beam_size=int(args['--beam-size']),
                             max_decoding_time_step=int(args['--max-decoding-time-step']))

    if tgt_path:
        # Only the first hypothesis of each sentence is scored
        # (presumably beam_search returns them ranked best-first -- confirm).
        top_hypotheses = [hyps[0] for hyps in hypotheses]
        bleu_score = compute_corpus_level_bleu_score(test_data_tgt, top_hypotheses)
        print('Corpus BLEU: {}'.format(bleu_score * 100), file=sys.stderr)

    # Explicit UTF-8: translated text routinely contains non-ASCII characters,
    # and the locale-dependent default encoding can raise UnicodeEncodeError.
    with open(args['OUTPUT_FILE'], 'w', encoding='utf-8') as f:
        # zip keeps the output limited to one line per source sentence; the
        # source sentence itself is not written.
        for _, hyps in zip(test_data_src, hypotheses):
            f.write(' '.join(hyps[0].value) + '\n')
def test(self):
    """ Evaluate the saved model and store the decoding results.

    Reloads the model via NMT.load (passing self.hparams.model_save_path and
    self.device), runs self.beam_search(), prints the corpus-level BLEU score
    (multiplied by 100), and writes every source sentence followed by its
    hypothesis and a blank line to self.hparams.test_res_path.
    """
    banner = '*' * 20
    print(banner, 'start test', banner)

    self.model = NMT.load(self.hparams.model_save_path, self.device)
    sources, references, hypotheses = self.beam_search()
    bleu_score = compute_corpus_level_bleu_score(references, hypotheses)
    print('Corpus BLEU: {}'.format(bleu_score * 100))

    result_path = self.hparams.test_res_path
    # Explicit UTF-8: source and translated text routinely contain non-ASCII
    # characters, and the locale-dependent default encoding can raise
    # UnicodeEncodeError.
    with open(result_path, 'w', encoding='utf-8') as f:
        for src_sent, hypo in zip(sources, hypotheses):
            src_line = ' '.join(src_sent)
            hypo_line = ' '.join(hypo.value)
            f.write(src_line + '\n' + hypo_line + '\n\n')
    print('save test result to {}'.format(result_path))

    print(banner, 'end test', banner)