コード例 #1
0
def _fairseq_generate(
    complex_filepath,
    output_pred_filepath,
    checkpoint_paths,
    complex_dictionary_path,
    simple_dictionary_path,
    beam=5,
    hypothesis_num=1,
    lenpen=1.0,
    diverse_beam_groups=None,
    diverse_beam_strength=0.5,
    sampling=False,
    max_tokens=16384,
    source_lang='complex',
    target_lang='simple',
    **kwargs,
):
    """Run fairseq-generate in-process on ``complex_filepath`` and write predictions.

    A temporary directory is populated as a raw-format dataset whose split is
    named ``tmp``. The input file is copied there as the source side and again
    as a dummy target side. The two dictionaries are copied as
    ``dict.{source_lang}.txt`` and ``dict.{target_lang}.txt``.

    ``generate.cli_main()`` is then called under ``mock_cli_args`` inside
    ``log_std_streams(<temp_dir>/generation.out)``. The hypotheses parsed from
    that file with ``fairseq_parse_all_hypotheses`` are reduced to one
    prediction per input sentence and written to ``output_pred_filepath``.

    Args:
        complex_filepath: Source-side input file to generate from.
        output_pred_filepath: Destination file for the selected predictions.
        checkpoint_paths: Iterable of checkpoint paths, joined with ``':'`` and
            passed as ``--path``.
        complex_dictionary_path: Dictionary copied to ``dict.{source_lang}.txt``.
        simple_dictionary_path: Dictionary copied to ``dict.{target_lang}.txt``.
        beam: Value for ``--beam``.
        hypothesis_num: Value for ``--nbest``. The ``hypothesis_num``-th
            (1-based) hypothesis of every sentence is the one written out.
        lenpen: Value for ``--lenpen``.
        diverse_beam_groups: Value for ``--diverse-beam-groups``; ``None`` is
            sent as ``-1``.
        diverse_beam_strength: Value for ``--diverse-beam-strength``.
        sampling: If true, ``--sampling --sampling-topk 10`` is also passed.
        max_tokens: Value for ``--max-tokens``.
        source_lang: Source language suffix used in dataset/dictionary names.
        target_lang: Target language suffix used in dataset/dictionary names.
        **kwargs: Extra options rendered with ``args_dict_to_str`` and appended
            after every option above.
    """
    with create_temp_dir() as temp_dir:
        new_complex_filepath = temp_dir / f'tmp.{source_lang}-{target_lang}.{source_lang}'
        dummy_simple_filepath = temp_dir / f'tmp.{source_lang}-{target_lang}.{target_lang}'
        shutil.copy(complex_filepath, new_complex_filepath)
        # The split also needs a target-side file; the source is reused as a placeholder.
        shutil.copy(complex_filepath, dummy_simple_filepath)
        shutil.copy(complex_dictionary_path, temp_dir / f'dict.{source_lang}.txt')
        shutil.copy(simple_dictionary_path, temp_dir / f'dict.{target_lang}.txt')
        # The command line is built as one string and tokenized with shlex.split
        # below, so paths must be quoted. Otherwise a path containing a space
        # would be split into several arguments. shlex.quote leaves ordinary
        # paths untouched.
        data_dir_arg = shlex.quote(str(temp_dir))
        checkpoints_arg = shlex.quote(':'.join(str(path) for path in checkpoint_paths))
        diverse_beam_groups_arg = diverse_beam_groups if diverse_beam_groups is not None else -1
        args = f'''
        {data_dir_arg} --dataset-impl raw --gen-subset tmp --path {checkpoints_arg}
        --beam {beam} --nbest {hypothesis_num} --lenpen {lenpen}
        --diverse-beam-groups {diverse_beam_groups_arg} --diverse-beam-strength {diverse_beam_strength}
        --max-tokens {max_tokens}
        --model-overrides "{{'encoder_embed_path': None, 'decoder_embed_path': None}}"
        --skip-invalid-size-inputs-valid-test
        '''
        if sampling:
            args += ' --sampling --sampling-topk 10'
        # FIXME: kwargs are appended last. If one duplicates an option above, it
        # appears twice on the command line and fairseq takes only the last one
        # into account.
        args += f' {args_dict_to_str(kwargs)}'
        args = remove_multiple_whitespaces(args.replace('\n', ' '))
        out_filepath = temp_dir / 'generation.out'
        with mute(mute_stderr=False):
            with log_std_streams(out_filepath):
                # Evaluate the model in batch mode.
                print(f'fairseq-generate {args}')
                with mock_cli_args(shlex.split(args)):
                    generate.cli_main()

        all_hypotheses = fairseq_parse_all_hypotheses(out_filepath)
        # Keep the hypothesis_num-th (1-based) hypothesis for each input sentence.
        predictions = [hypotheses[hypothesis_num - 1] for hypotheses in all_hypotheses]
        write_lines(predictions, output_pred_filepath)
コード例 #2
0
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq_cli.generate import cli_main

if __name__ == '__main__':
    # Executed as a script: hand control to fairseq's generation CLI entry point.
    run_generation = cli_main
    run_generation()
コード例 #3
0
ファイル: generate.py プロジェクト: teanakamura/fairseq
#!/usr/local/opt/python/bin/python3.7
# -*- coding: utf-8 -*-
import re
import sys

from fairseq_cli.generate import cli_main

if __name__ == '__main__':
    # Normalize the program name by stripping a trailing "-script.py",
    # "-script.pyw" or ".exe" suffix. Such suffixes are typically added by
    # console-script launcher wrappers.
    launcher_suffix = r'(-script\.pyw?|\.exe)?$'
    sys.argv[0] = re.sub(launcher_suffix, '', sys.argv[0])
    # Whatever cli_main returns becomes the process exit status.
    exit_status = cli_main()
    sys.exit(exit_status)
コード例 #4
0
def ls_cli_main(*args, **kwargs):
    """Run ``cli_main`` with the ``fs_modules`` directory beside this file as ``--user-dir``.

    ``--user-dir <this file's directory>/fs_modules`` is appended to the
    process-wide ``sys.argv`` before delegating. The option is presumably picked
    up when ``cli_main`` parses the command line. All positional and keyword
    arguments are forwarded to ``cli_main`` unchanged, and its return value is
    returned, so a caller may write ``sys.exit(ls_cli_main())``.

    Note: ``sys.argv`` is modified in place and stays modified after the call.
    """
    user_dir = str(pathlib.Path(__file__).parent.joinpath("fs_modules"))
    user_dir_args = ["--user-dir", user_dir]
    # sys.argv is global state. If it already ends with this exact pair (e.g.
    # on a repeated call in the same process), appending again would only grow
    # argv without changing which --user-dir value wins.
    if sys.argv[-2:] != user_dir_args:
        sys.argv.extend(user_dir_args)
    return cli_main(*args, **kwargs)