示例#1
0
def parse_args_and_arch(
    parser: argparse.ArgumentParser,
    input_args: Optional[List[str]] = None,
    parse_known: bool = False,
    suppress_defaults: bool = False,
    modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None,
):
    """Parse arguments after extending *parser* with the options of the
    selected architecture, task, BMUF and registry choices.

    Args:
        parser (ArgumentParser): the parser; it is modified in place (model-,
            task- and registry-specific arguments are added to it)
        input_args (List[str]): strings to parse, defaults to sys.argv
        parse_known (bool): only parse known arguments, similar to
            `ArgumentParser.parse_known_args`
        suppress_defaults (bool): parse while ignoring all default values;
            options whose parsed value is None are dropped from the result
        modify_parser (Optional[Callable[[ArgumentParser], None]]):
            function to modify the parser, e.g., to set default values

    Returns:
        The parsed ``argparse.Namespace``; when *parse_known* is true, a
        ``(namespace, extra)`` tuple where *extra* holds the unrecognized
        strings.

    Raises:
        RuntimeError: if ``--arch`` names neither a registered architecture
            nor a registered model.
        ValueError: if fp16 is combined with TPU (bf16 implies TPU).
    """
    if suppress_defaults:
        # Parse args without any default values. This requires us to parse
        # twice, once to identify all the necessary task/model args, and a second
        # time with all defaults set to None.
        first_pass = parse_args_and_arch(
            parser,
            input_args=input_args,
            parse_known=parse_known,
            suppress_defaults=False,
            modify_parser=modify_parser,
        )
        # With parse_known=True the recursive call returns (args, extra);
        # unpack it rather than calling vars() on a tuple.
        args = first_pass[0] if parse_known else first_pass
        suppressed_parser = argparse.ArgumentParser(add_help=False, parents=[parser])
        # NOTE(review): argparse appears to share Action objects between a
        # parent and a child parser, so this likely also resets the defaults
        # stored on *parser*'s own actions -- confirm callers don't reuse it.
        suppressed_parser.set_defaults(**dict.fromkeys(vars(args)))
        extra = None
        if parse_known:
            args, extra = suppressed_parser.parse_known_args(input_args)
        else:
            args = suppressed_parser.parse_args(input_args)
        args = argparse.Namespace(
            **{k: v for k, v in vars(args).items() if v is not None}
        )
        return (args, extra) if parse_known else args

    from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_CONFIG_REGISTRY, MODEL_REGISTRY

    # Before creating the true parser, we need to import optional user module
    # in order to eagerly import custom tasks, optimizers, architectures, etc.
    usr_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
    usr_parser.add_argument("--user-dir", default=None)
    usr_args, _ = usr_parser.parse_known_args(input_args)
    utils.import_user_module(usr_args)

    if modify_parser is not None:
        modify_parser(parser)

    # The parser doesn't know about model/criterion/optimizer-specific args, so
    # we parse twice. First we parse the model/criterion/optimizer, then we
    # parse a second time after adding the *-specific arguments.
    # If input_args is given, we will parse those args instead of sys.argv.
    args, _ = parser.parse_known_args(input_args)

    # Add model-specific args to parser.
    if hasattr(args, "arch"):
        model_specific_group = parser.add_argument_group(
            "Model-specific configuration",
            # Only include attributes which are explicitly given as command-line
            # arguments or which have default values.
            argument_default=argparse.SUPPRESS,
        )
        if args.arch in ARCH_MODEL_REGISTRY:
            ARCH_MODEL_REGISTRY[args.arch].add_args(model_specific_group)
        elif args.arch in MODEL_REGISTRY:
            MODEL_REGISTRY[args.arch].add_args(model_specific_group)
        else:
            raise RuntimeError(
                f"Unknown architecture {args.arch!r}: it is registered neither "
                "as an architecture nor as a model"
            )

    if hasattr(args, "task"):
        from fairseq.tasks import TASK_REGISTRY

        TASK_REGISTRY[args.task].add_args(parser)
    if getattr(args, "use_bmuf", False):
        # hack to support extra args for block distributed data parallelism
        from fairseq.optim.bmuf import FairseqBMUF

        FairseqBMUF.add_args(parser)

    # Add *-specific args to parser.
    from fairseq.registry import REGISTRIES

    for registry_name, registry in REGISTRIES.items():
        choice = getattr(args, registry_name, None)
        if choice is None:
            continue
        cls = registry["registry"][choice]
        if hasattr(cls, "add_args"):
            cls.add_args(parser)
        elif hasattr(cls, "__dataclass"):
            gen_parser_from_dataclass(parser, cls.__dataclass())

    # Modify the parser a second time, since defaults may have been reset
    if modify_parser is not None:
        modify_parser(parser)

    # Parse a second time.
    if parse_known:
        args, extra = parser.parse_known_args(input_args)
    else:
        args = parser.parse_args(input_args)
        extra = None

    # Post-process args.
    # A missing or None validation batch size falls back to the training one
    # (None when the parser has no --batch-size option at all).
    if getattr(args, "batch_size_valid", None) is None:
        args.batch_size_valid = getattr(args, "batch_size", None)
    if hasattr(args, "max_tokens_valid") and args.max_tokens_valid is None:
        args.max_tokens_valid = args.max_tokens
    if getattr(args, "memory_efficient_fp16", False):
        args.fp16 = True
    if getattr(args, "memory_efficient_bf16", False):
        args.bf16 = True
    args.tpu = getattr(args, "tpu", False)
    args.bf16 = getattr(args, "bf16", False)
    if args.bf16:
        args.tpu = True
    # getattr: parsers without an fp16 option must not raise AttributeError.
    if args.tpu and getattr(args, "fp16", False):
        raise ValueError("Cannot combine --fp16 and --tpu, use --bf16 on TPUs")

    if getattr(args, "seed", None) is None:
        args.seed = 1  # default seed for training
        args.no_seed_provided = True
    else:
        args.no_seed_provided = False

    # Apply architecture configuration.
    if hasattr(args, "arch") and args.arch in ARCH_CONFIG_REGISTRY:
        ARCH_CONFIG_REGISTRY[args.arch](args)

    return (args, extra) if parse_known else args
示例#2
0
文件: bmuf.py 项目: Tvicker/espresso
 def add_args(parser):
     """Register the BMUF options (the fields of FairseqBMUFConfig) on *parser*."""
     bmuf_defaults = FairseqBMUFConfig()
     gen_parser_from_dataclass(parser, bmuf_defaults)
示例#3
0
 def add_args(cls, parser):
     """Add the task's own options to *parser*.

     Options are generated from the class's ``__dataclass`` attribute;
     classes without one contribute nothing.
     """
     dataclass_type = getattr(cls, "__dataclass", None)
     if dataclass_type is None:
         return
     gen_parser_from_dataclass(parser, dataclass_type())
示例#4
0
def add_dataset_args(parser, train=False, gen=False):
    """Add the "dataset_data_loading" option group (fields of DatasetConfig).

    Args:
        parser: the parser (or argument container) to extend.
        train: not read by this function; kept for caller compatibility.
        gen: not read by this function; kept for caller compatibility.

    Returns:
        The argument group the dataset options were added to.
    """
    data_group = parser.add_argument_group("dataset_data_loading")
    gen_parser_from_dataclass(data_group, DatasetConfig())
    return data_group
示例#5
0
def add_generation_args(parser):
    """Add the "Generation" option group to *parser* and return it.

    The group first receives the common evaluation options, then the
    fields of GenerationConfig.
    """
    generation_group = parser.add_argument_group("Generation")
    add_common_eval_args(generation_group)
    gen_parser_from_dataclass(generation_group, GenerationConfig())
    return generation_group
示例#6
0
def add_common_eval_args(group):
    """Add the fields of CommonEvalParams as options on *group*."""
    eval_params = CommonEvalParams()
    gen_parser_from_dataclass(group, eval_params)
示例#7
0
 def add_args(parser):
     """Add criterion-specific arguments to the parser.

     The arguments are generated from the fields of
     ``SubsampledCrossEntropyWithAccuracyCriterionConfig``.
     """
     gen_parser_from_dataclass(
         parser, SubsampledCrossEntropyWithAccuracyCriterionConfig()
     )
示例#8
0
def add_common_eval_args(group):
    """Add the fields of CommonEvalConfig as options on *group*."""
    eval_config = CommonEvalConfig()
    gen_parser_from_dataclass(group, eval_config)
 def add_args(parser):
     """Register the inverse-square-root LR scheduler options on *parser*."""
     schedule_config = InverseSquareRootScheduleConfig()
     gen_parser_from_dataclass(parser, schedule_config)
示例#10
0
def add_ema_args(parser):
    """Add the "EMA configuration" option group to *parser*.

    The group is populated from the fields of EMAConfig.

    Returns:
        The created argument group, matching the other ``add_*_args``
        helpers that build a group.
    """
    group = parser.add_argument_group("EMA configuration")
    gen_parser_from_dataclass(group, EMAConfig())
    return group
示例#11
0
 def add_args(cls, parser):
     """Add model-specific arguments to the parser.

     Arguments come from the class's ``__dataclass`` attribute, if any.
     Defaults are deliberately omitted (``delete_default=True``) so that
     defaults set by the various architectures still take effect.
     """
     model_dataclass = getattr(cls, "__dataclass", None)
     if model_dataclass is None:
         return
     gen_parser_from_dataclass(parser, model_dataclass(), delete_default=True)
示例#12
0
 def add_args(cls, parser):
     """Add criterion-specific arguments to the parser.

     Arguments are generated from the class's ``__dataclass`` attribute;
     nothing is added when the class does not define one.
     """
     criterion_dc = getattr(cls, "__dataclass", None)
     if criterion_dc is None:
         return
     gen_parser_from_dataclass(parser, criterion_dc())
示例#13
0
def add_pruning_args(parser):
    """Add the "Pruning" option group, populated from PruningConfig.

    Returns:
        The argument group holding the pruning options.
    """
    pruning_group = parser.add_argument_group("Pruning")
    pruning_config = PruningConfig()
    gen_parser_from_dataclass(pruning_group, pruning_config)
    return pruning_group
示例#14
0
def add_optimization_args(parser):
    """Add the "optimization" option group (fields of OptimizationConfig).

    Returns:
        The argument group the optimization options were added to.
    """
    optim_group = parser.add_argument_group("optimization")
    optim_config = OptimizationConfig()
    gen_parser_from_dataclass(optim_group, optim_config)
    return optim_group
示例#15
0
 def test_argparse_convert_basic(self):
     """Fields of A become CLI options: ``--num-layers`` plus positional ``data``."""
     cli_parser = ArgumentParser()
     gen_parser_from_dataclass(cli_parser, A(), delete_default=True)
     argv = ["--num-layers", "10", "the/data/path"]
     parsed = cli_parser.parse_args(argv)
     self.assertEqual(parsed.num_layers, 10)
     self.assertEqual(parsed.data, "the/data/path")
示例#16
0
def add_checkpoint_args(parser):
    """Add the "checkpoint" option group (fields of CheckpointConfig).

    Returns:
        The argument group the checkpoint options were added to.
    """
    ckpt_group = parser.add_argument_group("checkpoint")
    ckpt_config = CheckpointConfig()
    gen_parser_from_dataclass(ckpt_group, ckpt_config)
    return ckpt_group
示例#17
0
 def add_args(cls, parser):
     """Add arguments to the parser for this LR scheduler.

     Arguments are generated from the class's ``__dataclass`` attribute;
     nothing is added when the class does not define one.
     """
     scheduler_dc = getattr(cls, "__dataclass", None)
     if scheduler_dc is None:
         return
     gen_parser_from_dataclass(parser, scheduler_dc())
示例#18
0
def add_eval_lm_args(parser):
    """Add the "LM Evaluation" option group to *parser*.

    The group first receives the common evaluation options
    (``add_common_eval_args``), then the fields of EvalLMConfig.

    Returns:
        The created argument group, matching the other ``add_*_args``
        helpers that build a group.
    """
    group = parser.add_argument_group("LM Evaluation")
    add_common_eval_args(group)
    gen_parser_from_dataclass(group, EvalLMConfig())
    return group
示例#19
0
 def add_args(parser):
     """Add task-specific arguments (fields of LanguageModelingConfig) to *parser*."""
     lm_config = LanguageModelingConfig()
     gen_parser_from_dataclass(parser, lm_config)
示例#20
0
def add_interactive_args(parser):
    """Add the "Interactive" option group (fields of InteractiveConfig).

    Returns:
        The created argument group, matching the other ``add_*_args``
        helpers that build a group.
    """
    group = parser.add_argument_group("Interactive")
    gen_parser_from_dataclass(group, InteractiveConfig())
    return group
示例#21
0
 def add_args(parser):
     """Add the adaptive-loss arguments (fields of AdaptiveLossConfig) to *parser*."""
     gen_parser_from_dataclass(parser, AdaptiveLossConfig())