def validate_and_set_default_args(args):
    """Validate parsed args and fill in defaults for vocab file paths.

    Mutates ``args`` in place: silences per-sentence generate output and
    derives default (char) dictionary paths from ``save_dir`` when the
    corresponding ``*_vocab_file`` flags were left unset.
    """
    # Prevents generate from printing individual translated sentences when
    # calculating BLEU score.
    args.quiet = True

    if not args.source_vocab_file:
        args.source_vocab_file = pytorch_translate_dictionary.default_dictionary_path(
            save_dir=args.save_dir, dialect=args.source_lang
        )
    if not args.target_vocab_file:
        args.target_vocab_file = pytorch_translate_dictionary.default_dictionary_path(
            save_dir=args.save_dir, dialect=args.target_lang
        )
    # Character-level source models additionally need a char dictionary.
    if args.arch == "char_source" and not args.char_source_vocab_file:
        args.char_source_vocab_file = (
            pytorch_translate_dictionary.default_char_dictionary_path(
                save_dir=args.save_dir, dialect=args.source_lang
            )
        )

    pytorch_translate_options.validate_preprocessing_args(args)
    pytorch_translate_options.validate_generation_args(args)

    # Multilingual setups get one default dictionary per encoder/decoder lang.
    if args.multiling_encoder_lang and not args.multiling_source_vocab_file:
        args.multiling_source_vocab_file = [
            pytorch_translate_dictionary.default_dictionary_path(
                save_dir=args.save_dir, dialect=f"src-{lang}"
            )
            for lang in args.multiling_encoder_lang
        ]
    if args.multiling_decoder_lang and not args.multiling_target_vocab_file:
        args.multiling_target_vocab_file = [
            pytorch_translate_dictionary.default_dictionary_path(
                save_dir=args.save_dir, dialect=f"trg-{lang}"
            )
            for lang in args.multiling_decoder_lang
        ]
def main():
    """Entry point: parse preprocessing flags, validate them, run preprocessing."""
    parser = argparse.ArgumentParser(
        description="PyTorch Translate - preprocessing"
    )
    pytorch_translate_options.add_preprocessing_args(parser)
    parsed_args = parser.parse_args()
    pytorch_translate_options.validate_preprocessing_args(parsed_args)
    preprocess_corpora(parsed_args)
def test_validate_preprocessing_args_monolingual_target_only(self):
    """Validation should pass for the semisupervised training task when
    only monolingual target data is provided."""
    ns = self.get_common_data_args_namespace()
    ns.task = "pytorch_translate_semisupervised"
    ns.train_mono_target_binary_path = test_utils.make_temp_file()
    # Should not raise.
    options.validate_preprocessing_args(ns)
def validate_and_set_default_args(args):
    """Validate parsed args and fill in defaults in place.

    Covers: silencing generate output, distributed-training init method,
    GPU-count sanity checks, the fp16/adversary incompatibility, and default
    (char/multilingual) dictionary paths derived from ``save_dir``.

    Raises:
        ValueError: if ``--local-num-gpus`` exceeds either
            ``--distributed-world-size`` or the number of visible GPUs.
    """
    # Prevents generate from printing individual translated sentences when
    # calculating BLEU score.
    args.quiet = True

    pytorch_translate_options.check_unsupported_fairseq_flags(args)

    # Default init method for multi-GPU training when the user didn't
    # specify one: a random local TCP port.
    if args.distributed_world_size > 1 and not args.distributed_init_method:
        args.distributed_init_method = (
            f"tcp://localhost:{random.randint(10000, 20000)}"
        )

    num_visible_gpus = torch.cuda.device_count()
    if args.local_num_gpus > args.distributed_world_size:
        raise ValueError(
            f"--local-num-gpus={args.local_num_gpus} must be "
            f"<= --distributed-world-size={args.distributed_world_size}."
        )
    if args.local_num_gpus > num_visible_gpus:
        raise ValueError(
            f"--local-num-gpus={args.local_num_gpus} must be "
            f"<= the number of GPUs: {num_visible_gpus}."
        )

    # AdversarialTrainer cannot run in fp16; fall back to fp32 with a warning.
    if args.fp16 and getattr(args, "adversary", False):
        print(
            "Warning: disabling fp16 training since it's not supported by AdversarialTrainer."
        )
        args.fp16 = False

    if not args.source_vocab_file:
        args.source_vocab_file = pytorch_translate_dictionary.default_dictionary_path(
            save_dir=args.save_dir, dialect=args.source_lang
        )
    if not args.target_vocab_file:
        args.target_vocab_file = pytorch_translate_dictionary.default_dictionary_path(
            save_dir=args.save_dir, dialect=args.target_lang
        )
    # All character-level source architectures need a char dictionary too.
    char_archs = ("char_source", "char_source_transformer", "char_source_hybrid")
    if args.arch in char_archs and not args.char_source_vocab_file:
        args.char_source_vocab_file = (
            pytorch_translate_dictionary.default_char_dictionary_path(
                save_dir=args.save_dir, dialect=args.source_lang
            )
        )

    pytorch_translate_options.validate_preprocessing_args(args)
    pytorch_translate_options.validate_generation_args(args)

    # Multilingual setups get one default dictionary per encoder/decoder lang.
    if args.multiling_encoder_lang and not args.multiling_source_vocab_file:
        args.multiling_source_vocab_file = [
            pytorch_translate_dictionary.default_dictionary_path(
                save_dir=args.save_dir, dialect=f"src-{lang}"
            )
            for lang in args.multiling_encoder_lang
        ]
    if args.multiling_decoder_lang and not args.multiling_target_vocab_file:
        args.multiling_target_vocab_file = [
            pytorch_translate_dictionary.default_dictionary_path(
                save_dir=args.save_dir, dialect=f"trg-{lang}"
            )
            for lang in args.multiling_decoder_lang
        ]
def test_validate_preprocessing_args_monolingual(self):
    """Validation should pass for the semisupervised training task when
    the required monolingual source and target data are both set."""
    ns = self.get_common_data_args_namespace()
    ns.task = "pytorch_translate_semisupervised"
    ns.train_mono_source_binary_path = test_utils.make_temp_file()
    ns.train_mono_target_text_file = test_utils.make_temp_file()
    # Should not raise.
    options.validate_preprocessing_args(ns)
def validate_and_set_default_args(args):
    """Validate parsed args and fill in defaults in place.

    Silences per-sentence generate output, checks the distributed world size
    against the visible GPU count, defaults the distributed init method, and
    derives default dictionary paths from ``save_dir``.

    Raises:
        ValueError: if ``--distributed-world-size`` is smaller than the
            number of visible GPUs.
    """
    # Prevents generate from printing individual translated sentences when
    # calculating BLEU score.
    args.quiet = True

    num_visible_gpus = torch.cuda.device_count()
    if args.distributed_world_size < num_visible_gpus:
        raise ValueError(
            f"--distributed-world-size={args.distributed_world_size} "
            f"must be >= the number of GPUs: {num_visible_gpus}."
        )

    # Default init method for multi-GPU training when the user didn't
    # specify one: a random local TCP port.
    if args.distributed_world_size > 1 and not args.distributed_init_method:
        args.distributed_init_method = (
            f"tcp://localhost:{random.randint(10000, 20000)}"
        )

    if not args.source_vocab_file:
        args.source_vocab_file = pytorch_translate_dictionary.default_dictionary_path(
            save_dir=args.save_dir, dialect=args.source_lang
        )
    if not args.target_vocab_file:
        args.target_vocab_file = pytorch_translate_dictionary.default_dictionary_path(
            save_dir=args.save_dir, dialect=args.target_lang
        )
    # Character-level source models additionally need a char dictionary.
    if args.arch == "char_source" and not args.char_source_vocab_file:
        args.char_source_vocab_file = (
            pytorch_translate_dictionary.default_char_dictionary_path(
                save_dir=args.save_dir, dialect=args.source_lang
            )
        )

    pytorch_translate_options.validate_preprocessing_args(args)
    pytorch_translate_options.validate_generation_args(args)

    # Multilingual setups get one default dictionary per encoder/decoder lang.
    if args.multiling_encoder_lang and not args.multiling_source_vocab_file:
        args.multiling_source_vocab_file = [
            pytorch_translate_dictionary.default_dictionary_path(
                save_dir=args.save_dir, dialect=f"src-{lang}"
            )
            for lang in args.multiling_encoder_lang
        ]
    if args.multiling_decoder_lang and not args.multiling_target_vocab_file:
        args.multiling_target_vocab_file = [
            pytorch_translate_dictionary.default_dictionary_path(
                save_dir=args.save_dir, dialect=f"trg-{lang}"
            )
            for lang in args.multiling_decoder_lang
        ]
def validate_args(args):
    """Run every option validator over the parsed args, in order."""
    # Order matches the original call sequence; each validator may raise.
    validators = (
        pytorch_translate_options.check_unsupported_fairseq_flags,
        pytorch_translate_options.validate_preprocessing_args,
        pytorch_translate_options.validate_generation_args,
    )
    for validate in validators:
        validate(args)
def test_validate_preprocesing_args(self):
    """Validation should pass with just the minimum required args."""
    ns = self.get_common_data_args_namespace()
    # Should not raise.
    options.validate_preprocessing_args(ns)