def set_default_args(args):
    """Fill in default values on ``args`` and sanity-check the GPU settings.

    ``args`` is modified in place; nothing is returned.

    The following happens, in order:
      * ``args.quiet`` is forced to True.
      * When ``args.distributed_world_size`` is greater than 1, a
        ``tcp://localhost:<port>`` init method with a random port in
        [10000, 20000] is used unless one was already given, and
        ``args.local_num_gpus`` is validated.
      * ``args.fp16`` is switched off (with a printed warning) when
        ``args.adversary`` is set.
      * Every empty vocab-file option is replaced by the default path built
        from ``args.save_dir`` by ``pytorch_translate_dictionary``.

    Raises:
        ValueError: if ``args.local_num_gpus`` is larger than
            ``args.distributed_world_size`` or than
            ``torch.cuda.device_count()``.
    """

    def word_vocab_path(dialect):
        # Default word-level dictionary location for one dialect.
        return pytorch_translate_dictionary.default_dictionary_path(
            save_dir=args.save_dir, dialect=dialect
        )

    def char_vocab_path(dialect):
        # Default character-level dictionary location for one dialect.
        return pytorch_translate_dictionary.default_char_dictionary_path(
            save_dir=args.save_dir, dialect=dialect
        )

    # Keeps generate from printing every translated sentence while the BLEU
    # score is being calculated.
    args.quiet = True

    # Multi-GPU training: fall back to a default init method when the user
    # did not specify one.
    # NOTE(review): the GPU-count checks are assumed to belong to this
    # multi-GPU branch — confirm against the upstream formatting.
    if args.distributed_world_size > 1:
        if not args.distributed_init_method:
            args.distributed_init_method = (
                f"tcp://localhost:{random.randint(10000, 20000)}"
            )

        if args.local_num_gpus > args.distributed_world_size:
            raise ValueError(
                f"--local-num-gpus={args.local_num_gpus} must be "
                f"<= --distributed-world-size={args.distributed_world_size}."
            )

        if args.local_num_gpus > torch.cuda.device_count():
            raise ValueError(
                f"--local-num-gpus={args.local_num_gpus} must be "
                f"<= the number of GPUs: {torch.cuda.device_count()}."
            )

    # fp16 cannot be combined with the adversarial trainer.
    if args.fp16:
        if getattr(args, "adversary", False):
            print(
                "Warning: disabling fp16 training since it's not supported by AdversarialTrainer."
            )
            args.fp16 = False

    # Word-level vocab files.
    if not args.source_vocab_file:
        args.source_vocab_file = word_vocab_path(args.source_lang)
    if not args.target_vocab_file:
        args.target_vocab_file = word_vocab_path(args.target_lang)

    # Character-level vocab files, only for architectures that use them.
    if args.arch in constants.ARCHS_FOR_CHAR_SOURCE:
        if not args.char_source_vocab_file:
            args.char_source_vocab_file = char_vocab_path(args.source_lang)
    if args.arch in constants.ARCHS_FOR_CHAR_TARGET:
        if not args.char_target_vocab_file:
            args.char_target_vocab_file = char_vocab_path(args.target_lang)

    # Multilingual setups get one default vocab file per language.
    if args.multiling_encoder_lang and not args.multiling_source_vocab_file:
        args.multiling_source_vocab_file = [
            word_vocab_path(f"src-{lang_code}")
            for lang_code in args.multiling_encoder_lang
        ]
    if args.multiling_decoder_lang and not args.multiling_target_vocab_file:
        args.multiling_target_vocab_file = [
            word_vocab_path(f"trg-{lang_code}")
            for lang_code in args.multiling_decoder_lang
        ]
def get_dict_paths(
    vocabulary_args: Optional[List[str]], langs: List[str], save_dir: str
) -> Dict[str, str]:
    """
    Extract dictionary files based on --vocabulary argument, for the given
    languages `langs`.

    Args:
        vocabulary_args: optional list where each element is a str with the
            format "lang:vocab_file". Only the first ":" separates the
            language from the file, so the file path may itself contain ":".
            Entries for languages not in `langs` are ignored; when a
            language appears more than once, the last entry wins.
        langs: languages for which a dictionary path is required.
        save_dir: directory handed to
            pytorch_translate_dictionary.default_dictionary_path for every
            language that has no explicit entry in `vocabulary_args`.

    Returns:
        A dict mapping each language in `langs` to its dictionary path.

    Raises:
        ValueError: if an element of `vocabulary_args` contains no ":".
    """
    dicts = {}
    if vocabulary_args is not None:
        for vocab_config in vocabulary_args:
            # Split on the first ":" only: a plain split(":") breaks on paths
            # such as "C:\\vocab.txt" that contain further colons.
            lang, separator, vocab = vocab_config.partition(":")
            if not separator:
                raise ValueError(
                    f"Invalid vocabulary entry {vocab_config!r}: expected "
                    f"the format 'lang:vocab_file'."
                )
            if lang in langs:
                dicts[lang] = vocab
    # Languages without an explicit entry fall back to the default location.
    for lang in langs:
        if lang not in dicts:
            dicts[lang] = pytorch_translate_dictionary.default_dictionary_path(
                save_dir=save_dir, dialect=lang
            )
    return dicts