Beispiel #1
0
def set_default_args(args):
    """Fill in defaults on the parsed argument namespace `args`, in place.

    Steps, in order:
      * force `args.quiet` on;
      * when `args.distributed_world_size > 1`, pick a tcp://localhost init
        method on a random port if none was supplied, and validate
        `args.local_num_gpus`;
      * turn fp16 off when an `adversary` attribute is present and truthy;
      * derive any vocab file path that was left unset from `args.save_dir`
        and the relevant language/dialect.

    Raises:
        ValueError: if `args.local_num_gpus` exceeds either
            `args.distributed_world_size` or `torch.cuda.device_count()`
            (only checked when `args.distributed_world_size > 1`).
    """
    # Keep generate from printing each translated sentence while the BLEU
    # score is being computed.
    args.quiet = True

    # Multi-GPU training: supply an init method when the user gave none.
    if args.distributed_world_size > 1:
        if not args.distributed_init_method:
            args.distributed_init_method = (
                f"tcp://localhost:{random.randint(10000, 20000)}")

        if args.local_num_gpus > args.distributed_world_size:
            raise ValueError(
                f"--local-num-gpus={args.local_num_gpus} must be "
                f"<= --distributed-world-size={args.distributed_world_size}.")
        if args.local_num_gpus > torch.cuda.device_count():
            raise ValueError(
                f"--local-num-gpus={args.local_num_gpus} must be "
                f"<= the number of GPUs: {torch.cuda.device_count()}.")

    # `adversary` may be absent from args, hence getattr with a default.
    if args.fp16:
        if getattr(args, "adversary", False):
            print(
                "Warning: disabling fp16 training since it's not supported by AdversarialTrainer."
            )
            args.fp16 = False

    # Word-level vocab files.
    if not args.source_vocab_file:
        source_vocab = pytorch_translate_dictionary.default_dictionary_path(
            save_dir=args.save_dir, dialect=args.source_lang)
        args.source_vocab_file = source_vocab
    if not args.target_vocab_file:
        target_vocab = pytorch_translate_dictionary.default_dictionary_path(
            save_dir=args.save_dir, dialect=args.target_lang)
        args.target_vocab_file = target_vocab

    # Char-level vocab files, only for architectures listed in constants.
    if args.arch in constants.ARCHS_FOR_CHAR_SOURCE:
        if not args.char_source_vocab_file:
            char_source_vocab = (
                pytorch_translate_dictionary.default_char_dictionary_path(
                    save_dir=args.save_dir, dialect=args.source_lang))
            args.char_source_vocab_file = char_source_vocab
    if args.arch in constants.ARCHS_FOR_CHAR_TARGET:
        if not args.char_target_vocab_file:
            char_target_vocab = (
                pytorch_translate_dictionary.default_char_dictionary_path(
                    save_dir=args.save_dir, dialect=args.target_lang))
            args.char_target_vocab_file = char_target_vocab

    # Multilingual vocab files: one path per encoder / decoder language.
    if args.multiling_encoder_lang and not args.multiling_source_vocab_file:
        encoder_vocabs = []
        for lang in args.multiling_encoder_lang:
            encoder_vocabs.append(
                pytorch_translate_dictionary.default_dictionary_path(
                    save_dir=args.save_dir, dialect=f"src-{lang}"))
        args.multiling_source_vocab_file = encoder_vocabs
    if args.multiling_decoder_lang and not args.multiling_target_vocab_file:
        decoder_vocabs = []
        for lang in args.multiling_decoder_lang:
            decoder_vocabs.append(
                pytorch_translate_dictionary.default_dictionary_path(
                    save_dir=args.save_dir, dialect=f"trg-{lang}"))
        args.multiling_target_vocab_file = decoder_vocabs
Beispiel #2
0
def get_dict_paths(vocabulary_args: Optional[List[str]], langs: List[str],
                   save_dir: str) -> Dict[str, str]:
    """
    Map each language in `langs` to a dictionary file path.

    Paths are taken from the --vocabulary argument when given; any language
    without an explicit entry falls back to
    pytorch_translate_dictionary.default_dictionary_path under `save_dir`.

    Args:
        vocabulary_args: entries of the form "lang:vocab_file", or None.
            Only the first ":" separates the language from the path, so the
            path itself may contain colons (e.g. "en:hdfs://host/vocab.txt").
            Entries for languages not in `langs` are ignored; if a language
            appears more than once, the last entry wins.
        langs: languages that need a dictionary path.
        save_dir: directory handed to default_dictionary_path for fallbacks.

    Returns:
        Dict with exactly one path per language in `langs`.

    Raises:
        ValueError: if an entry of `vocabulary_args` contains no ":".
    """
    dicts = {}
    for vocab_config in vocabulary_args or ():
        # Split on the first ":" only; a plain split(":") fails to unpack
        # when the vocab path contains a colon (URL schemes, drive letters).
        lang, separator, vocab = vocab_config.partition(":")
        if not separator:
            raise ValueError(
                f"Invalid vocabulary entry {vocab_config!r}: "
                f"expected the format 'lang:vocab_file'.")
        if lang in langs:
            dicts[lang] = vocab
    for lang in langs:
        if lang not in dicts:
            dicts[lang] = pytorch_translate_dictionary.default_dictionary_path(
                save_dir=save_dir, dialect=lang)
    return dicts