def validate_and_set_default_args(args):
    """Fill in default vocabulary paths on ``args`` and validate it, in place.

    ``args.quiet`` is forced on. Each of the following that is still unset is
    pointed at the default path that ``pytorch_translate_dictionary`` computes
    for ``args.save_dir``:

    * ``source_vocab_file`` and ``target_vocab_file``;
    * ``char_source_vocab_file``, only when ``args.arch == "char_source"``;
    * ``multiling_source_vocab_file`` and ``multiling_target_vocab_file``,
      only when the matching ``multiling_*_lang`` list is non-empty.

    The preprocessing and generation validators from
    ``pytorch_translate_options`` run after the single-language defaults are
    filled in and before the multilingual ones. Nothing is returned.
    """
    # Prevents generate from printing individual translated sentences when
    # calculating BLEU score.
    args.quiet = True

    if not args.source_vocab_file:
        args.source_vocab_file = pytorch_translate_dictionary.default_dictionary_path(
            save_dir=args.save_dir, dialect=args.source_lang
        )
    if not args.target_vocab_file:
        args.target_vocab_file = pytorch_translate_dictionary.default_dictionary_path(
            save_dir=args.save_dir, dialect=args.target_lang
        )

    # A default char dictionary path is only filled in for the "char_source"
    # architecture.
    if args.arch == "char_source":
        if not args.char_source_vocab_file:
            args.char_source_vocab_file = (
                pytorch_translate_dictionary.default_char_dictionary_path(
                    save_dir=args.save_dir, dialect=args.source_lang
                )
            )

    pytorch_translate_options.validate_preprocessing_args(args)
    pytorch_translate_options.validate_generation_args(args)

    # One default path per listed language; the "src-"/"trg-" dialect prefix
    # keeps encoder-side and decoder-side dictionaries apart.
    if args.multiling_encoder_lang and not args.multiling_source_vocab_file:
        args.multiling_source_vocab_file = [
            pytorch_translate_dictionary.default_dictionary_path(
                save_dir=args.save_dir, dialect=f"src-{lang}"
            )
            for lang in args.multiling_encoder_lang
        ]
    if args.multiling_decoder_lang and not args.multiling_target_vocab_file:
        args.multiling_target_vocab_file = [
            pytorch_translate_dictionary.default_dictionary_path(
                save_dir=args.save_dir, dialect=f"trg-{lang}"
            )
            for lang in args.multiling_decoder_lang
        ]
def validate_and_set_default_args(args):
    """Validate training ``args`` and fill in derived defaults, in place.

    Steps, in order:

    * force ``args.quiet`` on;
    * run ``pytorch_translate_options.check_unsupported_fairseq_flags``;
    * when ``distributed_world_size > 1``:
      - pick a random localhost TCP ``distributed_init_method`` if none was
        given;
      - check ``local_num_gpus`` against the world size and against
        ``torch.cuda.device_count()``;
    * turn fp16 off, printing a warning, when ``args.adversary`` is truthy;
    * point unset vocabulary-file arguments at the defaults that
      ``pytorch_translate_dictionary`` computes for ``args.save_dir``;
    * run the preprocessing and generation validators;
    * fill in unset multilingual vocabulary-file lists.

    Raises:
        ValueError: if ``local_num_gpus`` exceeds ``distributed_world_size``
            or the number of GPUs reported by ``torch.cuda.device_count()``.
            Only checked when ``distributed_world_size > 1``.
    """
    # Prevents generate from printing individual translated sentences when
    # calculating BLEU score.
    args.quiet = True

    pytorch_translate_options.check_unsupported_fairseq_flags(args)

    # Set a default init method for multi-GPU training if the user didn't
    # specify one.
    if args.distributed_world_size > 1:
        if not args.distributed_init_method:
            # NOTE(review): a random port in [10000, 20000] can collide with
            # another job on the same host; set distributed_init_method
            # explicitly when that matters.
            args.distributed_init_method = (
                f"tcp://localhost:{random.randint(10000, 20000)}"
            )

        if args.local_num_gpus > args.distributed_world_size:
            raise ValueError(
                f"--local-num-gpus={args.local_num_gpus} must be "
                f"<= --distributed-world-size={args.distributed_world_size}."
            )
        # Query once so the comparison and the error message agree.
        num_gpus = torch.cuda.device_count()
        if args.local_num_gpus > num_gpus:
            raise ValueError(
                f"--local-num-gpus={args.local_num_gpus} must be "
                f"<= the number of GPUs: {num_gpus}."
            )

    if args.fp16 and getattr(args, "adversary", False):
        print(
            "Warning: disabling fp16 training since it's not supported by "
            "AdversarialTrainer."
        )
        args.fp16 = False

    if not args.source_vocab_file:
        args.source_vocab_file = pytorch_translate_dictionary.default_dictionary_path(
            save_dir=args.save_dir, dialect=args.source_lang
        )
    if not args.target_vocab_file:
        args.target_vocab_file = pytorch_translate_dictionary.default_dictionary_path(
            save_dir=args.save_dir, dialect=args.target_lang
        )

    # Architectures for which a default char source dictionary path is
    # filled in.
    char_source_archs = (
        "char_source",
        "char_source_transformer",
        "char_source_hybrid",
    )
    if args.arch in char_source_archs and not args.char_source_vocab_file:
        args.char_source_vocab_file = (
            pytorch_translate_dictionary.default_char_dictionary_path(
                save_dir=args.save_dir, dialect=args.source_lang
            )
        )

    pytorch_translate_options.validate_preprocessing_args(args)
    pytorch_translate_options.validate_generation_args(args)

    # One default path per listed language; the "src-"/"trg-" dialect prefix
    # keeps encoder-side and decoder-side dictionaries apart.
    if args.multiling_encoder_lang and not args.multiling_source_vocab_file:
        args.multiling_source_vocab_file = [
            pytorch_translate_dictionary.default_dictionary_path(
                save_dir=args.save_dir, dialect=f"src-{lang}"
            )
            for lang in args.multiling_encoder_lang
        ]
    if args.multiling_decoder_lang and not args.multiling_target_vocab_file:
        args.multiling_target_vocab_file = [
            pytorch_translate_dictionary.default_dictionary_path(
                save_dir=args.save_dir, dialect=f"trg-{lang}"
            )
            for lang in args.multiling_decoder_lang
        ]
def validate_and_set_default_args(args):
    """Check the GPU/world-size setup and fill in defaults on ``args``.

    ``args`` is modified in place. ``quiet`` is forced on. A random localhost
    TCP ``distributed_init_method`` is chosen for multi-process runs that did
    not give one. Unset vocabulary-file arguments are pointed at the defaults
    that ``pytorch_translate_dictionary`` computes for ``args.save_dir``, and
    the preprocessing and generation validators are run before the
    multilingual defaults are filled in.

    Raises:
        ValueError: if ``distributed_world_size`` is smaller than
            ``torch.cuda.device_count()``.
    """
    # Prevents generate from printing individual translated sentences when
    # calculating BLEU score.
    args.quiet = True

    gpu_count = torch.cuda.device_count()
    if gpu_count > args.distributed_world_size:
        raise ValueError(
            f"--distributed-world-size={args.distributed_world_size} "
            f"must be >= the number of GPUs: {gpu_count}."
        )

    # Keep a user-supplied init method; otherwise pick a random local port.
    if args.distributed_world_size > 1:
        args.distributed_init_method = (
            args.distributed_init_method
            or f"tcp://localhost:{random.randint(10000, 20000)}"
        )

    if not args.source_vocab_file:
        args.source_vocab_file = pytorch_translate_dictionary.default_dictionary_path(
            save_dir=args.save_dir, dialect=args.source_lang
        )
    if not args.target_vocab_file:
        args.target_vocab_file = pytorch_translate_dictionary.default_dictionary_path(
            save_dir=args.save_dir, dialect=args.target_lang
        )

    # A default char dictionary path is only filled in for the "char_source"
    # architecture.
    if args.arch == "char_source":
        if not args.char_source_vocab_file:
            args.char_source_vocab_file = (
                pytorch_translate_dictionary.default_char_dictionary_path(
                    save_dir=args.save_dir, dialect=args.source_lang
                )
            )

    pytorch_translate_options.validate_preprocessing_args(args)
    pytorch_translate_options.validate_generation_args(args)

    # One default path per listed language; the "src-"/"trg-" dialect prefix
    # keeps encoder-side and decoder-side dictionaries apart.
    if args.multiling_encoder_lang and not args.multiling_source_vocab_file:
        args.multiling_source_vocab_file = [
            pytorch_translate_dictionary.default_dictionary_path(
                save_dir=args.save_dir, dialect=f"src-{lang}"
            )
            for lang in args.multiling_encoder_lang
        ]
    if args.multiling_decoder_lang and not args.multiling_target_vocab_file:
        args.multiling_target_vocab_file = [
            pytorch_translate_dictionary.default_dictionary_path(
                save_dir=args.save_dir, dialect=f"trg-{lang}"
            )
            for lang in args.multiling_decoder_lang
        ]
def validate_args(args):
    """Validate the arguments needed for generation.

    Runs ``pytorch_translate_options.validate_generation_args`` and then
    checks the following:

    * ``args.path`` is set;
    * the source and target vocabulary files exist;
    * every entry of ``args.source_text_file`` exists;
    * ``args.target_text_file`` exists.

    Raises:
        AssertionError: if any of the checks above fails.

    NOTE(review): these checks are ``assert`` statements, so they are skipped
    when Python runs with ``-O``.
    """
    pytorch_translate_options.validate_generation_args(args)

    assert args.path is not None, "--path required for generation!"
    assert args.source_vocab_file and os.path.isfile(
        args.source_vocab_file
    ), "Please specify a valid file for --source-vocab-file"
    assert args.target_vocab_file and os.path.isfile(
        args.target_vocab_file
    ), "Please specify a valid file for --target-vocab-file"
    # ``source_text_file`` is iterated, so it is presumably a list of paths.
    # Check it is set first so an unset value reports the message below
    # instead of "'NoneType' object is not iterable".
    assert args.source_text_file and all(
        src_file and os.path.isfile(src_file) for src_file in args.source_text_file
    ), "Please specify a valid file for --source-text-file"
    assert args.target_text_file and os.path.isfile(
        args.target_text_file
    ), "Please specify a valid file for --target-text-file"
def validate_args(args):
    """Validate ``args`` with the shared ``pytorch_translate_options`` checks.

    The three validators run in a fixed order: the unsupported-fairseq-flag
    check, then the preprocessing checks, then the generation checks. Any
    exception a validator raises propagates unchanged.
    """
    options = pytorch_translate_options
    options.check_unsupported_fairseq_flags(args)
    options.validate_preprocessing_args(args)
    options.validate_generation_args(args)