コード例 #1
0
ファイル: test_binaries.py プロジェクト: yf1291/nlp4
def eval_lm_main(data_dir):
    """Run ``eval_lm.main`` against ``data_dir``.

    The model is loaded from ``checkpoint_last.pt`` inside ``data_dir`` and the
    progress bar is disabled via ``--no-progress-bar``.

    Args:
        data_dir: Directory passed as the positional data argument; it is also
            where ``checkpoint_last.pt`` is looked up.
    """
    parser = options.get_eval_lm_parser()
    checkpoint_path = os.path.join(data_dir, 'checkpoint_last.pt')
    cli_args = [data_dir, '--path', checkpoint_path, '--no-progress-bar']
    parsed_args = options.parse_args_and_arch(parser, cli_args)
    eval_lm.main(parsed_args)
コード例 #2
0
ファイル: rerank_utils.py プロジェクト: veralily/fairseq
def _preprocess_and_eval_lm(
    text_file,
    cur_lm_dict,
    preprocess_directory,
    cur_language_model,
    batch_size,
    lm_score_file,
    max_tokens="1024",
):
    """Binarize ``text_file`` and write the LM evaluation output to a file.

    Runs ``preprocess.main`` with ``--only-source`` on ``text_file`` (passed as
    ``--trainpref``), then runs ``eval_lm.main`` on the resulting directory
    with ``--output-word-probs``. Everything ``eval_lm.main`` prints to stdout
    is captured into ``lm_score_file``.

    Args:
        text_file: Path handed to ``--trainpref``.
        cur_lm_dict: Dictionary path handed to ``--srcdict``.
        preprocess_directory: ``--destdir`` for preprocessing and the data
            directory for evaluation.
        cur_language_model: Checkpoint path handed to ``--path``.
        batch_size: Value for ``--batch-size`` (converted with ``str``).
        lm_score_file: File that receives the redirected stdout.
        max_tokens: Value for ``--max-tokens``; ``None`` omits the flag.
    """
    preprocess_lm_param = [
        "--only-source",
        "--trainpref",
        text_file,
        "--srcdict",
        cur_lm_dict,
        "--destdir",
        preprocess_directory,
    ]
    preprocess_parser = options.get_preprocessing_parser()
    preprocess.main(preprocess_parser.parse_args(preprocess_lm_param))

    eval_lm_param = [
        preprocess_directory,
        "--path",
        cur_language_model,
        "--output-word-probs",
        "--batch-size",
        str(batch_size),
    ]
    if max_tokens is not None:
        eval_lm_param += ["--max-tokens", max_tokens]
    # The data was binarized via --trainpref above, so it is evaluated as the
    # "train" subset.
    eval_lm_param += ["--sample-break-mode", "eos", "--gen-subset", "train"]

    eval_lm_parser = options.get_eval_lm_parser()
    input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param)

    with open(lm_score_file, "w") as f:
        with redirect_stdout(f):
            eval_lm.main(input_args)


def lm_scoring(
    preprocess_directory,
    bpe_status,
    gen_output,
    pre_gen,
    cur_lm_dict,
    cur_lm_name,
    cur_language_model,
    cur_lm_bpe_code,
    batch_size,
    lm_score_file,
    target_lang,
    source_lang,
    prefix_len=None,
):
    """Score generated hypotheses with a language model.

    Depending on ``bpe_status`` the hypotheses are taken from a different file
    under ``pre_gen``; that file is then binarized and evaluated, and the
    evaluation output is written to ``lm_score_file``:

    * ``"no bpe"``: the BPE-free source/hypothesis/target text from
      ``gen_output`` is written out and the hypothesis file is scored.
    * ``"shared"``: the existing ``rescore_data.<target_lang>`` file is scored
      as is (no ``--max-tokens`` limit is passed in this mode).
    * ``"different"``: the BPE-free text is written out, the LM's own BPE
      codes are applied with subword-nmt, and the re-encoded file is scored.

    Any other ``bpe_status`` value does nothing.

    Args:
        preprocess_directory: Destination directory for the binarized data.
        bpe_status: One of ``"no bpe"``, ``"shared"`` or ``"different"``.
        gen_output: Object exposing ``no_bpe_source``, ``no_bpe_hypo`` and
            ``no_bpe_target``.
        pre_gen: Directory holding the intermediate rescoring files.
        cur_lm_dict: Dictionary path for the language model.
        cur_lm_name: Unused by this function.
        cur_language_model: Path to the language model checkpoint.
        cur_lm_bpe_code: BPE codes file (only read in ``"different"`` mode).
        batch_size: Batch size for LM evaluation.
        lm_score_file: Output file for the LM evaluation log.
        target_lang: Suffix of the hypothesis file to score.
        source_lang: Suffix of the source file written next to it.
        prefix_len: Only validated here: when set, ``bpe_status`` must be
            ``"different"``. It is not otherwise used.

    Raises:
        AssertionError: If ``prefix_len`` is given with a ``bpe_status`` other
            than ``"different"``.
        subprocess.CalledProcessError: If the subword-nmt ``apply_bpe.py``
            script exits with a non-zero status.
    """
    if prefix_len is not None:
        assert (bpe_status == "different"
                ), "bpe status must be different to use prefix len"

    if bpe_status == "no bpe":
        # run lm on output without bpe
        rescore_file = pre_gen + "/rescore_data_no_bpe."
        # Name the files after the actual language pair (as the "different"
        # branch does) so the file scored below is always the one holding the
        # hypotheses, not just for a de->en setup.
        write_reprocessed(
            gen_output.no_bpe_source,
            gen_output.no_bpe_hypo,
            gen_output.no_bpe_target,
            rescore_file + source_lang,
            rescore_file + target_lang,
            pre_gen + "/reference_file_no_bpe",
        )
        _preprocess_and_eval_lm(
            rescore_file + target_lang,
            cur_lm_dict,
            preprocess_directory,
            cur_language_model,
            batch_size,
            lm_score_file,
        )

    elif bpe_status == "shared":
        _preprocess_and_eval_lm(
            pre_gen + "/rescore_data." + target_lang,
            cur_lm_dict,
            preprocess_directory,
            cur_language_model,
            batch_size,
            lm_score_file,
            max_tokens=None,
        )

    elif bpe_status == "different":
        import sys

        rescore_file = pre_gen + "/rescore_data_no_bpe."
        rescore_bpe = pre_gen + "/rescore_data_new_bpe."

        write_reprocessed(
            gen_output.no_bpe_source,
            gen_output.no_bpe_hypo,
            gen_output.no_bpe_target,
            rescore_file + source_lang,
            rescore_file + target_lang,
            pre_gen + "/reference_file_no_bpe",
            bpe_symbol=None,
        )

        # apply LM bpe to nbest list
        bpe_src_param = [
            "-c",
            cur_lm_bpe_code,
            "--input",
            rescore_file + target_lang,
            "--output",
            rescore_bpe + target_lang,
        ]
        apply_bpe_script = os.path.join(
            os.path.dirname(__file__),
            "subword-nmt/subword_nmt/apply_bpe.py",
        )
        # Run under the current interpreter and fail loudly: the steps below
        # are meaningless if the BPE-encoded file was not produced.
        # (fastBPE's "applybpe" could be substituted here for subword-nmt.)
        subprocess.check_call(
            [sys.executable or "python", apply_bpe_script] + bpe_src_param,
            shell=False,
        )

        _preprocess_and_eval_lm(
            rescore_bpe + target_lang,
            cur_lm_dict,
            preprocess_directory,
            cur_language_model,
            batch_size,
            lm_score_file,
        )