示例#1
0
def validate_and_set_default_args(args):
    """Force quiet output, fill in default vocabulary paths, and validate ``args``.

    ``args`` is mutated in place. Every vocabulary-file option that is empty is
    pointed at the default dictionary path under ``args.save_dir`` for the
    matching language. Multilingual encoder/decoder lists use ``src-<lang>`` /
    ``trg-<lang>`` dialects. The character-level source vocabulary default is
    only filled in when ``args.arch == "char_source"``.

    The preprocessing and generation validators from
    ``pytorch_translate_options`` run between the single-language and the
    multilingual defaulting steps. Anything they raise propagates.
    """

    def dict_path(dialect):
        # args.save_dir is read at call time, so the multilingual defaults
        # below see whatever value the validators leave behind.
        return pytorch_translate_dictionary.default_dictionary_path(
            save_dir=args.save_dir, dialect=dialect
        )

    # Stop generate from echoing each translated sentence while BLEU is computed.
    args.quiet = True

    if not args.source_vocab_file:
        args.source_vocab_file = dict_path(args.source_lang)
    if not args.target_vocab_file:
        args.target_vocab_file = dict_path(args.target_lang)

    needs_char_vocab = args.arch == "char_source" and not args.char_source_vocab_file
    if needs_char_vocab:
        args.char_source_vocab_file = (
            pytorch_translate_dictionary.default_char_dictionary_path(
                save_dir=args.save_dir, dialect=args.source_lang
            )
        )

    pytorch_translate_options.validate_preprocessing_args(args)
    pytorch_translate_options.validate_generation_args(args)

    if args.multiling_encoder_lang and not args.multiling_source_vocab_file:
        args.multiling_source_vocab_file = [
            dict_path(f"src-{lang}") for lang in args.multiling_encoder_lang
        ]
    if args.multiling_decoder_lang and not args.multiling_target_vocab_file:
        args.multiling_target_vocab_file = [
            dict_path(f"trg-{lang}") for lang in args.multiling_decoder_lang
        ]
def main():
    """Preprocessing entry point: parse flags, validate them, preprocess corpora."""
    arg_parser = argparse.ArgumentParser(description="PyTorch Translate - preprocessing")
    pytorch_translate_options.add_preprocessing_args(arg_parser)

    parsed_args = arg_parser.parse_args()
    pytorch_translate_options.validate_preprocessing_args(parsed_args)

    preprocess_corpora(parsed_args)
示例#3
0
 def test_validate_preprocessing_args_monolingual_target_only(self):
     """
     Make sure we pass validation with the semisupervised training
     task when we only have monolingual target data (only
     ``train_mono_target_binary_path`` is set; no monolingual source path).
     """
     args = self.get_common_data_args_namespace()
     args.task = "pytorch_translate_semisupervised"
     args.train_mono_target_binary_path = test_utils.make_temp_file()
     options.validate_preprocessing_args(args)
示例#4
0
def validate_and_set_default_args(args):
    """Validate ``args`` and fill in defaults for unset options, in place.

    Steps, in order:

    * force ``args.quiet`` so generate does not print individual translated
      sentences while computing BLEU;
    * reject unsupported fairseq flags;
    * when ``args.distributed_world_size > 1``:
      - pick a random localhost TCP init method if none was given;
      - check ``args.local_num_gpus`` against the world size and the number
        of visible CUDA devices;
    * turn off fp16 (with a printed warning) when ``args.adversary`` is set;
    * point any empty vocabulary-file option at the default dictionary path
      under ``args.save_dir``;
    * run the preprocessing and generation validators;
    * default the per-language multilingual vocabulary lists.

    Raises:
        ValueError: if ``args.local_num_gpus`` exceeds the distributed world
            size or the CUDA device count. This is only checked when the
            world size is > 1. Exceptions raised by the
            ``pytorch_translate_options`` checks also propagate.
    """
    # Prevents generate from printing individual translated sentences when
    # calculating BLEU score.
    args.quiet = True

    pytorch_translate_options.check_unsupported_fairseq_flags(args)

    if args.distributed_world_size > 1:
        # Default the multi-GPU init method when the user didn't specify one.
        if not args.distributed_init_method:
            args.distributed_init_method = (
                f"tcp://localhost:{random.randint(10000, 20000)}"
            )

        if args.local_num_gpus > args.distributed_world_size:
            raise ValueError(
                f"--local-num-gpus={args.local_num_gpus} must be "
                f"<= --distributed-world-size={args.distributed_world_size}."
            )
        num_gpus = torch.cuda.device_count()
        if args.local_num_gpus > num_gpus:
            raise ValueError(
                f"--local-num-gpus={args.local_num_gpus} must be "
                f"<= the number of GPUs: {num_gpus}."
            )

    # ``adversary`` may be absent from args, hence getattr with a default.
    if args.fp16 and getattr(args, "adversary", False):
        print(
            "Warning: disabling fp16 training since it's not supported by AdversarialTrainer."
        )
        args.fp16 = False

    if not args.source_vocab_file:
        args.source_vocab_file = pytorch_translate_dictionary.default_dictionary_path(
            save_dir=args.save_dir, dialect=args.source_lang
        )
    if not args.target_vocab_file:
        args.target_vocab_file = pytorch_translate_dictionary.default_dictionary_path(
            save_dir=args.save_dir, dialect=args.target_lang
        )

    # Architectures that get a default character-level source vocab path.
    char_source_archs = ("char_source", "char_source_transformer", "char_source_hybrid")
    if args.arch in char_source_archs and not args.char_source_vocab_file:
        args.char_source_vocab_file = (
            pytorch_translate_dictionary.default_char_dictionary_path(
                save_dir=args.save_dir, dialect=args.source_lang
            )
        )

    pytorch_translate_options.validate_preprocessing_args(args)
    pytorch_translate_options.validate_generation_args(args)

    # Multilingual vocab defaults use "src-<lang>" / "trg-<lang>" dialects.
    if args.multiling_encoder_lang and not args.multiling_source_vocab_file:
        args.multiling_source_vocab_file = [
            pytorch_translate_dictionary.default_dictionary_path(
                save_dir=args.save_dir, dialect=f"src-{lang}"
            )
            for lang in args.multiling_encoder_lang
        ]
    if args.multiling_decoder_lang and not args.multiling_target_vocab_file:
        args.multiling_target_vocab_file = [
            pytorch_translate_dictionary.default_dictionary_path(
                save_dir=args.save_dir, dialect=f"trg-{lang}"
            )
            for lang in args.multiling_decoder_lang
        ]
示例#5
0
 def test_validate_preprocessing_args_monolingual(self):
     """
     Semisupervised task validation should succeed once both monolingual
     inputs are provided: a binary source path and a target text file.
     """
     namespace = self.get_common_data_args_namespace()
     namespace.task = "pytorch_translate_semisupervised"
     mono_source = test_utils.make_temp_file()
     mono_target = test_utils.make_temp_file()
     namespace.train_mono_source_binary_path = mono_source
     namespace.train_mono_target_text_file = mono_target
     options.validate_preprocessing_args(namespace)
示例#6
0
def validate_and_set_default_args(args):
    """Check distributed settings, fill in default vocab paths, and validate ``args``.

    ``args`` is mutated in place:

    * ``quiet`` is forced on;
    * for a world size above 1 with no ``distributed_init_method``, a random
      localhost TCP init method is chosen;
    * empty vocabulary-file options are pointed at the default dictionary
      paths under ``args.save_dir``. The char-level source default applies
      only to ``arch == "char_source"``.

    Raises:
        ValueError: if ``distributed_world_size`` is smaller than
            ``torch.cuda.device_count()``. Exceptions from the
            ``pytorch_translate_options`` validators propagate.
    """
    # Quiet mode keeps generate from printing each translated sentence during
    # BLEU scoring.
    args.quiet = True

    gpu_count = torch.cuda.device_count()
    if args.distributed_world_size < gpu_count:
        raise ValueError(
            f"--distributed-world-size={args.distributed_world_size} "
            f"must be >= the number of GPUs: {gpu_count}."
        )

    # Multi-GPU runs fall back to a random localhost port when no init method
    # was supplied.
    if args.distributed_world_size > 1 and not args.distributed_init_method:
        args.distributed_init_method = f"tcp://localhost:{random.randint(10000, 20000)}"

    if not args.source_vocab_file:
        args.source_vocab_file = pytorch_translate_dictionary.default_dictionary_path(
            save_dir=args.save_dir, dialect=args.source_lang
        )
    if not args.target_vocab_file:
        args.target_vocab_file = pytorch_translate_dictionary.default_dictionary_path(
            save_dir=args.save_dir, dialect=args.target_lang
        )

    wants_char_vocab = args.arch == "char_source" and not args.char_source_vocab_file
    if wants_char_vocab:
        args.char_source_vocab_file = (
            pytorch_translate_dictionary.default_char_dictionary_path(
                save_dir=args.save_dir, dialect=args.source_lang
            )
        )

    pytorch_translate_options.validate_preprocessing_args(args)
    pytorch_translate_options.validate_generation_args(args)

    if args.multiling_encoder_lang and not args.multiling_source_vocab_file:
        args.multiling_source_vocab_file = [
            pytorch_translate_dictionary.default_dictionary_path(
                save_dir=args.save_dir, dialect=f"src-{encoder_lang}"
            )
            for encoder_lang in args.multiling_encoder_lang
        ]
    if args.multiling_decoder_lang and not args.multiling_target_vocab_file:
        args.multiling_target_vocab_file = [
            pytorch_translate_dictionary.default_dictionary_path(
                save_dir=args.save_dir, dialect=f"trg-{decoder_lang}"
            )
            for decoder_lang in args.multiling_decoder_lang
        ]
示例#7
0
文件: train.py 项目: dwraft/translate
def validate_args(args):
    """Run the pytorch_translate option checks against ``args``.

    The checks run in this order: unsupported fairseq flags, preprocessing
    arguments, generation arguments. Any exception a check raises propagates
    unchanged.
    """
    options_module = pytorch_translate_options
    options_module.check_unsupported_fairseq_flags(args)
    options_module.validate_preprocessing_args(args)
    options_module.validate_generation_args(args)
示例#8
0
 def test_validate_preprocesing_args(self):
     """Make sure validation passes with the minimum required args.

     Uses the namespace from ``get_common_data_args_namespace`` unmodified.
     """
     args = self.get_common_data_args_namespace()
     options.validate_preprocessing_args(args)