def add_argparse_args(cls, parser: ArgumentParser, dest: str = None) -> None:
    """Register the command-line arguments of this class on `parser`.

    Override this if you don't rely on simple-parsing to create the arguments.

    Parameters
    ----------
    parser : ArgumentParser
        The parser to which the arguments are added.
    dest : str, optional
        Namespace attribute under which the parsed values are stored. For a
        dataclass, a falsy value (the default, None) is replaced by the
        camel-cased qualified class name. Not forwarded in the
        LightningDataModule branch.

    Raises
    ------
    NotImplementedError
        If `cls` is neither a dataclass nor a LightningDataModule subclass.
    """
    if not is_dataclass(cls):
        if not issubclass(cls, LightningDataModule):
            raise NotImplementedError(
                f"Don't know how to add command-line arguments for class "
                f"{cls}, since it isn't a dataclass and doesn't override the "
                f"`add_argparse_args` method!\n"
                f"Either make class {cls} a dataclass and add command-line "
                f"arguments as fields, or add an implementation for the "
                f"`add_argparse_args` and `from_argparse_args` classmethods.")
        # TODO: Test this case out (using a LightningDataModule as a Setting).
        super().add_argparse_args(parser)  # type: ignore
        return

    parser.add_arguments(cls, dest=dest or camel_case(cls.__qualname__))
def main_sl():
    """Run the PnnMethod on a supervised-learning (SL) Setting.

    Parses the setting (dest "setting"), the Config (dest "config") and the
    method (dest "method") from the command line, applies the method to the
    setting and prints the results summary.

    Returns the results object produced by `setting.apply`.
    """
    parser = ArgumentParser(description=__doc__, add_dest_to_option_strings=False)

    # NOTE: PNN is written for the DomainIncrementalSetting, where the action
    # space is shared by every task. TaskIncrementalSLSetting is used for now.
    # TODO: switch to `DomainIncrementalSetting` for the "setting" dest.
    parser.add_arguments(TaskIncrementalSLSetting, dest="setting")
    Config.add_argparse_args(parser, dest="config")
    PnnMethod.add_argparse_args(parser, dest="method")

    parsed = parser.parse_args()

    setting: TaskIncrementalSLSetting = TaskIncrementalSLSetting.from_argparse_args(
        parsed, dest="setting"
    )
    config: Config = Config.from_argparse_args(parsed, dest="config")
    method: PnnMethod = PnnMethod.from_argparse_args(parsed, dest="method")
    method.config = config

    results = setting.apply(method, config=config)
    print(results.summary())
    return results
def parse_setting_and_method_instances(
    setting: Union[Setting, Type[Setting]],
    method: Union[Method, Type[Method]],
    argv: Union[str, List[str]] = None,
    strict_args: bool = False,
) -> Tuple[Setting, Method]:
    """Return a (Setting, Method) pair, parsing whichever was given as a class.

    An instance is returned as-is. A class has its arguments added to one
    shared parser, and an instance is then built from the parsed `argv`.

    Parameters
    ----------
    setting : Setting or Type[Setting]
    method : Method or Type[Method]
    argv : str or list of str, optional
        Arguments to parse. None lets the parser use its own default.
    strict_args : bool, optional
        If True, use `parse_args` (unknown arguments are an error). Otherwise
        unknown arguments are logged as a warning and ignored.

    Returns
    -------
    Tuple[Setting, Method]
    """
    # TODO: Should we raise an error if an argument appears both in the Setting
    # and the Method?
    parser = ArgumentParser(description=__doc__, add_dest_to_option_strings=False)

    setting_is_class = not isinstance(setting, Setting)
    method_is_class = not isinstance(method, Method)

    if setting_is_class:
        assert issubclass(setting, Setting)
        setting.add_argparse_args(parser)
    if method_is_class:
        assert method is not None
        assert issubclass(method, Method)
        method.add_argparse_args(parser)

    if strict_args:
        parsed = parser.parse_args(argv)
    else:
        parsed, leftovers = parser.parse_known_args(argv)
        if leftovers:
            logger.warning(UserWarning(f"Unused command-line args: {leftovers}"))

    if setting_is_class:
        setting = setting.from_argparse_args(parsed)
    if method_is_class:
        method = method.from_argparse_args(parsed)
    return setting, method
def main():
    """Parse `Settings` from the command line, then train or test with them."""
    parser = ArgumentParser()
    parser.add_arguments(Settings, dest='opts')
    # Enable shell tab-completion before the real parse happens.
    argcomplete.autocomplete(parser)
    opts = parser.parse_args().opts

    run = train if opts.train else test
    run(opts)
def test_passing_instance():
    """A dataclass *instance* given to `add_arguments` provides the defaults."""

    @dataclass
    class Foo:
        a: int = 123

    parser = ArgumentParser()
    parser.add_arguments(Foo(456), dest="foo")
    namespace = parser.parse_args("")
    # The instance's value (456) is expected, not the class default (123).
    assert namespace.foo.a == 456, vars(namespace)
def add_model_specific_args(cls, arg_parser: ArgumentParser):
    """Add the wrapped model's arguments to `arg_parser`, minus `--env`.

    Delegates to `cls.Model.add_model_specific_args`, then detaches the
    "--env" option string (if one was registered). It does this through
    argparse's "resolve" conflict handler.

    Parameters
    ----------
    arg_parser : ArgumentParser
        Parser that is extended in place.
    """
    cls.Model.add_model_specific_args(arg_parser)

    # Find the first action that owns '--env'. If the model didn't add one,
    # there is nothing to remove; the previous `[...][0]` lookup raised an
    # opaque IndexError in that case.
    env_action = next(
        (action for action in arg_parser._actions if "--env" in action.option_strings),
        None,
    )
    if env_action is not None:
        # NOTE: relies on argparse's private conflict-resolution helper. It
        # drops the option string, and removes the action entirely once it
        # has no option strings left.
        arg_parser._handle_conflict_resolve(None, [("--env", env_action)])
def add_argparse_args(cls, parser: ArgumentParser, dest: str = ""):
    """Register one command-line argument per field of this Method dataclass.

    NOTE: This behaves exactly like the base implementation. It is kept here
    only to illustrate how a Method can customise its arguments.

    Parameters
    ----------
    parser : ArgumentParser
        Parser that receives the arguments.
    dest : str, optional
        Attribute of the parsed namespace that will hold this Method's values.
        When empty, the camel-cased qualified class name is used.
    """
    target_dest = dest if dest else camel_case(cls.__qualname__)
    parser.add_arguments(cls, dest=target_dest)
def test_issue_29():
    """A variadic `Tuple[str, ...]` field accepts several values."""
    from simple_parsing import ArgumentParser

    @dataclass
    class MyCli:
        asdf: Tuple[str, ...]

    parser = ArgumentParser()
    parser.add_arguments(MyCli, dest="args")
    parsed = parser.parse_args(["--asdf", "asdf", "fgfh"])
    assert parsed.args == MyCli(asdf=("asdf", "fgfh"))
def test_issue62():
    """Enum fields, and lists of Enums, are parsed from Enum member names.

    `MyPreferences.setup(...)` (presumably provided by `TestSetup`) parses the
    given argument string. Fields not mentioned on the command line keep their
    dataclass defaults, as the expected instances below show.
    """
    import enum
    from simple_parsing.helpers import list_field
    from typing import List
    parser = ArgumentParser()

    # Enum with string values.
    class Color(enum.Enum):
        RED = "red"
        ORANGE = "orange"
        BLUE = "blue"

    # Enum with int values. On the command line, members are still given by
    # name (e.g. "MONTREAL"), as the assertions below check.
    class Temperature(enum.Enum):
        HOT = 1
        WARM = 0
        COLD = -1
        MONTREAL = -35

    @dataclass
    class MyPreferences(TestSetup):
        """You can use Enums"""
        color: Color = Color.BLUE # my favorite colour
        # a list of colors
        color_list: List[Color] = list_field(Color.ORANGE)
        # Some floats.
        floats: List[float] = list_field(1.1, 2.2, 3.3)
        # pick a temperature
        temp: Temperature = Temperature.WARM
        # a list of temperatures
        temp_list: List[Temperature] = list_field(Temperature.COLD, Temperature.WARM)

    # NOTE(review): `parser` is never used to parse anything below; the
    # assertions go through `MyPreferences.setup` instead.
    parser.add_arguments(MyPreferences, "my_preferences")

    # `temp_list` is not passed here, so it keeps its default.
    assert MyPreferences.setup(
        "--color ORANGE --color_list RED BLUE --temp MONTREAL"
    ) == MyPreferences(
        color=Color.ORANGE,
        color_list=[Color.RED, Color.BLUE],
        temp=Temperature.MONTREAL,
        temp_list=[Temperature.COLD, Temperature.WARM],
    )
    # Same as above, but `temp_list` is overridden too.
    assert MyPreferences.setup(
        "--color ORANGE --color_list RED BLUE --temp MONTREAL --temp_list MONTREAL HOT"
    ) == MyPreferences(
        color=Color.ORANGE,
        color_list=[Color.RED, Color.BLUE],
        temp=Temperature.MONTREAL,
        temp_list=[Temperature.MONTREAL, Temperature.HOT],
    )
    # Sanity checks on the Enum itself: lookup by name and by value.
    assert Temperature["MONTREAL"] is Temperature.MONTREAL
    assert Temperature(-35) is Temperature.MONTREAL
def to_simple_parsing_args(some_dataclass_type):
    """Parse an instance of `some_dataclass_type` from the command line.

    Meant to be wrapped in a classmethod, so a dataclass can build itself from
    the process arguments::

        @classmethod
        def get_args(cls):
            args: cls = to_simple_parsing_args(cls)
            return args

    The parser is created with ConflictResolution.NONE.
    """
    from simple_parsing import ArgumentParser, ConflictResolution

    parser = ArgumentParser(conflict_resolution=ConflictResolution.NONE)
    parser.add_arguments(some_dataclass_type, dest='cfg')
    return parser.parse_args().cfg
def main():
    """Parse `Options` from the command line and print the result.

    Roughly equivalent to writing the argparse subparsers by hand::

        subparsers = parser.add_subparsers(title="config", required=False)
        parser.set_defaults(config=AConfig())
        a_parser = subparsers.add_parser("a", help="A help.")
        a_parser.add_arguments(AConfig, dest="config")
        b_parser = subparsers.add_parser("b", help="B help.")
        b_parser.add_arguments(BConfig, dest="config")
    """
    parser = ArgumentParser()
    parser.add_arguments(Options, dest="options")
    namespace = parser.parse_args()
    print(namespace)
    print(namespace.options)
def from_known_args(cls,
                    argv: Union[str, List[str]] = None,
                    reorder: bool = True,
                    strict: bool = False) -> Tuple[P, List[str]]:
    """Parse an instance of `cls` from `argv`; also return the unused args.

    Parameters
    ----------
    argv : str or list of str, optional
        Arguments to parse. Defaults to ``sys.argv[1:]``. A string is split
        with ``shlex.split``.
    reorder : bool, optional
        Forwarded as ``attempt_to_reorder`` to ``parse_known_args``. Only used
        when ``strict`` is False.
    strict : bool, optional
        If True, parse with ``parse_args``, so unknown arguments are an error.

    Returns
    -------
    Tuple[P, List[str]]
        The new instance (with its ``_argv`` attribute set to the parsed argv)
        and the unused arguments (always empty when ``strict`` is True).
    """
    # NOTE: `cls` need not be a dataclass, but it must implement both
    # `add_argparse_args` and `from_argparse_args`.
    if argv is None:
        argv = sys.argv[1:]
    logger.debug(f"parsing an instance of class {cls} from argv {argv}")
    if isinstance(argv, str):
        argv = shlex.split(argv)

    parser = ArgumentParser(description=cls.__doc__, add_dest_to_option_strings=False)
    cls.add_argparse_args(parser)

    # TODO: Set temporarily on the class, so its accessible in the class constructor
    previous_cls_argv = cls._argv
    cls._argv = argv
    try:
        if strict:
            args = parser.parse_args(argv)
            unused_args = []
        else:
            args, unused_args = parser.parse_known_args(
                argv, attempt_to_reorder=reorder)
            if unused_args:
                logger.debug(
                    RuntimeWarning(
                        f"Unknown/unused args when parsing class {cls}: {unused_args}"
                    ))
        instance: P = cls.from_argparse_args(args)
        # Save the argv that were used to create the instance on its `_argv`
        # attribute.
        instance._argv = argv
    finally:
        # Always restore the class-level `_argv`, even when parsing fails
        # (argparse exits via SystemExit on bad input) or `from_argparse_args`
        # raises. Otherwise the temporary value leaks onto `cls`.
        cls._argv = previous_cls_argv
    return instance, unused_args
def parse_hparams_and_config() -> Tuple[HParams, MnistConfig]:
    """Parse the training hyper-parameters and the MNIST config from sys.argv.

    Each group is registered with a default-constructed instance as its
    `default`. Arguments that neither group recognizes are ignored.
    """
    # Training settings
    parser = ArgumentParser(description='PyTorch MNIST Example')
    for group_type, group_dest in ((HParams, "hparams"), (MnistConfig, "config")):
        parser.add_arguments(group_type, group_dest, default=group_type())
    known, _unknown = parser.parse_known_args()
    return known.hparams, known.config
def main_rl():
    """Run the PnnMethod on a reinforcement-learning (RL) Setting.

    The Config and the PnnMethod are parsed from the command line. The setting
    is a fixed two-task "cartpole" TaskIncrementalRLSetting whose `length`
    entry changes between schedule keys 0 and 1000. Prints the results summary
    and the objective.

    Returns the results object produced by `setting.apply`.
    """
    parser = ArgumentParser(description=__doc__, add_dest_to_option_strings=False)
    Config.add_argparse_args(parser, dest="config")
    PnnMethod.add_argparse_args(parser, dest="method")

    # NOTE: Only observe_state_directly=True has been validated. With False it
    # runs, but convergence hasn't been checked.
    task_schedule = {
        0: {"gravity": 10, "length": 0.3},
        1000: {"gravity": 10, "length": 0.5},
    }
    setting = TaskIncrementalRLSetting(
        dataset="cartpole",
        observe_state_directly=True,
        nb_tasks=2,
        train_task_schedule=task_schedule,
    )

    parsed = parser.parse_args()
    config: Config = Config.from_argparse_args(parsed, dest="config")
    method: PnnMethod = PnnMethod.from_argparse_args(parsed, dest="method")
    method.config = config

    results = setting.apply(method, config=config)
    print(results.summary())
    print(f"objective: {results.objective}")
    return results
def test_arg_and_dataclass_with_same_name(silent):
    """A plain `--a` argument clashes with a dataclass field named `a`.

    The ArgumentError may surface either when the dataclass is added or when
    the parser is first used, so both calls sit inside `raises`.
    """

    @dataclass
    class SomeClass:
        a: int = 1  # some docstring for attribute 'a'

    parser = ArgumentParser()
    parser.add_argument("--a", default=123)
    with raises(argparse.ArgumentError):
        parser.add_arguments(SomeClass, dest="some_class")
        parser.parse_args("")
def test_cmd_false_doesnt_create_conflicts():
    """A field with `cmd=False` doesn't compete for its option string."""

    @dataclass
    class A:
        batch_size: int = field(default=10, cmd=False)

    @dataclass
    class B:
        batch_size: int = 20

    parser = ArgumentParser(conflict_resolution=ConflictResolution.NONE)
    parser.add_arguments(A, "a")
    parser.add_arguments(B, "b")
    ns = parser.parse_args(["--batch_size", "32"])

    # Only B exposes --batch_size, so A keeps its default.
    assert ns.a == A()
    assert ns.b == B(batch_size=32)
def update_settings(opts: object, argv: List[str] = None):
    """Parse settings of the same type as `opts`, using `opts` as the defaults.

    Uses `argparse`, `argcomplete` and `simple_parsing` under the hood. If
    `opts` is `Serializable`, it is first passed through
    `_update_settings_from_file`. That helper returns possibly-updated
    `opts`/`argv` plus parent parsers for the final parser.

    Parameters
    ----------
    opts
        A dataclass *instance*; its values are the defaults for the parse.
    argv
        Arguments to parse. Defaults to ``sys.argv[1:]``.

    Returns
    -------
    The parsed settings object.

    Raises
    ------
    ValueError
        If `opts` is not a dataclass instance (including when it is a
        dataclass type).
    """
    if not is_dataclass(opts):
        raise ValueError('Cannot update args on non-dataclass class')
    if isinstance(opts, type):
        # `is_dataclass` is also True for the dataclass *type*. In that case
        # `type(opts)` below would be `type` itself, so reject it up front.
        raise ValueError(
            f'Expected a dataclass instance, got the dataclass type {opts!r}')

    # Use default system argv if not supplied.
    argv = sys.argv[1:] if argv is None else argv

    # Update from config file, if applicable.
    parser_parents = []
    if isinstance(opts, Serializable):
        opts, argv, parser_parents = _update_settings_from_file(opts, argv)

    parser = ArgumentParser(parents=parser_parents)
    parser.add_arguments(type(opts), dest='opts', default=opts)
    if use_argcomplete:
        argcomplete.autocomplete(parser)
    args = parser.parse_args(argv)
    return args.opts
def test_vanilla_argparse_issue64():
    """Show that ArgumentDefaultsHelpFormatter needs a `help` to add defaults.

    argparse's ArgumentDefaultsHelpFormatter doesn't add "(default: xyz)" if
    the 'help' argument isn't already passed! This begs the question: should
    simple-parsing always add a 'help' argument, so that the formatter can
    then add the default string after it?

    The heading of the generic optional-arguments section was renamed from
    "optional arguments:" to "options:" in Python 3.10, so the expected text
    depends on the running interpreter.
    """
    import argparse
    import sys

    options_heading = ("options:" if sys.version_info >= (3, 10)
                       else "optional arguments:")

    parser = ArgumentParser(
        "issue64", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    group = parser.add_argument_group("Options ['options']",
                                      description="These are the options")
    group.add_argument("--foo", type=str, metavar="str", default="aaa",
                       help="Description")
    group.add_argument("--bar", type=str, metavar="str", default="bbb")

    from io import StringIO
    s = StringIO()
    parser.print_help(file=s)
    s.seek(0)
    assert s.read() == textwrap.dedent(f"""\
        usage: issue64 [-h] [--foo str] [--bar str]

        {options_heading}
          -h, --help  show this help message and exit

        Options ['options']:
          These are the options

          --foo str   Description (default: aaa)
          --bar str
        """)
def demo():
    """Smoke-test the EwcMethod on a small two-task "cartpole" RL setting."""
    # Adding arguments for each group directly:
    parser = ArgumentParser(description=__doc__)
    EwcMethod.add_argparse_args(parser, dest="method")
    parser.add_arguments(Config, "config")
    parsed = parser.parse_args()

    method = EwcMethod.from_argparse_args(parsed, dest="method")
    config: Config = parsed.config

    # Train and test share one schedule; gravity and length change at key 1000.
    schedule = {
        0: {"gravity": 10, "length": 0.2},
        1000: {"gravity": 100, "length": 1.2},
        # 2000: {"gravity": 10, "length": 0.2},
    }
    setting = TaskIncrementalRLSetting(
        dataset="cartpole",
        train_task_schedule=schedule,
        test_task_schedule=schedule,
        observe_state_directly=True,
        # max_steps=1000,
    )
    # Supervised alternatives:
    # from sequoia.settings import TaskIncrementalSetting, ClassIncrementalSetting
    # setting = ClassIncrementalSetting(dataset="mnist", nb_tasks=5)
    # setting = TaskIncrementalSetting(dataset="mnist", nb_tasks=5)
    results = setting.apply(method, config=config)
    print(results.summary())
def test_multiple_at_same_dest_throws_error():
    """Registering a second dataclass under an already-used `dest` fails."""

    @dataclass
    class SomeClass:
        a: int = 123

    shared_dest = "some_class"
    parser = ArgumentParser()
    parser.add_arguments(SomeClass, shared_dest)
    with raises(argparse.ArgumentError):
        parser.add_arguments(SomeClass, shared_dest)
def baseline_demo_command_line(setting_type=None):
    """Apply the BaselineMethod to a Setting parsed from the command line.

    Parameters
    ----------
    setting_type : type, optional
        The Setting class whose arguments are parsed under the "setting" dest.
        Defaults to the supervised-learning `TaskIncrementalSetting`. Pass
        `TaskIncrementalRLSetting` for the reinforcement-learning variant.

    Returns the results object produced by `setting.apply`.
    """
    if setting_type is None:
        # Supervised Learning Setting (use TaskIncrementalRLSetting for RL).
        setting_type = TaskIncrementalSetting

    # The module docstring goes in `description`: ArgumentParser's first
    # positional parameter is `prog`, which would put the whole docstring
    # into the usage line.
    parser = ArgumentParser(description=__doc__, add_dest_to_option_strings=False)
    # Only one dataclass can be registered per `dest`; adding a second one
    # under dest="setting" raises argparse.ArgumentError. Register just the
    # selected Setting type.
    parser.add_arguments(setting_type, dest="setting")
    parser.add_arguments(Config, dest="config")
    BaselineMethod.add_argparse_args(parser, dest="method")

    args = parser.parse_args()

    setting: Setting = args.setting
    config: Config = args.config
    method: BaselineMethod = BaselineMethod.from_argparse_args(args, dest="method")

    results = setting.apply(method, config=config)
    print(results.summary())
    return results
def get_args():
    """Parse command-line arguments in two passes.

    The first pass only reads --runner, --environment and --agent (ignoring
    everything else) to pick the classes from their registries. Each chosen
    class then extends the parser before the final, full parse.
    """
    parser = ArgumentParser()
    parser.add_argument('--runner', default='test', choices=runner_registry.keys(),
                        help="Specify whether to train or test an agent")
    parser.add_argument('--environment', default="car",
                        choices=environment_registry.keys(),
                        help="Environment to train/test agent on")
    parser.add_argument('--agent', default='scp', choices=agent_registry.keys(),
                        help='Agent to train/test')
    selection, _ = parser.parse_known_args()

    runner_class = runner_registry[selection.runner]
    agent_class = agent_registry[selection.agent]
    environment_class = environment_registry[selection.environment]

    # Each add_argparse_args returns the parser to use from then on.
    for component in (agent_class, environment_class, runner_class):
        parser = component.add_argparse_args(parser)
    return parser.parse_args()
def test_vanilla_argparse_beheviour(self, options: Dict[str, Any], passed_arg: str,
                                    expected_value: Any):
    """Check what a plain `add_argument("--foo", **options)` parses to.

    `passed_arg` is shell-split after "--foo", unless it is the DONT_PASS
    sentinel, in which case no arguments are given.
    """
    parser = ArgumentParser()
    parser.add_argument("--foo", **options)
    argv = "" if passed_arg is DONT_PASS else shlex.split("--foo " + passed_arg)
    assert parser.parse_args(argv).foo == expected_value
def test_store_false_action():
    """`--no-cache` (the dashed variant of `no_cache`) sets the flag to False.

    With `add_option_string_dash_variants=True`, passing `--no-cache` must
    yield exactly False, and omitting it must yield exactly True.
    """
    parser = ArgumentParser(add_option_string_dash_variants=True)
    parser.add_arguments(Foo, "foo")

    # `is` (rather than `==`) also rejects non-bool values such as 0 / 1.
    foo: Foo = parser.parse_args("--no-cache".split()).foo
    assert foo.no_cache is False

    foo = parser.parse_args("".split()).foo
    assert foo.no_cache is True
def test_parsing_twice():
    """The same parser can be used for more than one parse."""

    @dataclass
    class Foo:
        a: int = 123

    parser = ArgumentParser()
    parser.add_arguments(Foo, dest="foo")

    for argv, expected in (("", 123), ("--a 456".split(), 456)):
        ns = parser.parse_args(argv)
        assert ns.foo.a == expected, vars(ns)
def test_works_fine_with_other_argparse_arguments(simple_attribute, silent):
    """Dataclass arguments coexist with arguments added via `add_argument`."""
    some_type, passed_value, expected_value = simple_attribute

    @dataclass
    class SomeClass:
        a: some_type  # type: ignore
        """some docstring for attribute 'a'"""

    x = 123
    parser = ArgumentParser()
    parser.add_argument("--x", type=int)
    parser.add_arguments(SomeClass, dest="some_class")

    argv = shlex.split(f"--x {x} --a {passed_value}")
    expected = argparse.Namespace(some_class=SomeClass(a=expected_value), x=x)
    assert parser.parse_args(argv) == expected
def test_arg_and_dataclass_with_same_name_after_prefixing(silent):
    """A plain `--pre.a` argument clashes with the nested field `pre.a`.

    `Parent.pre` is a nested dataclass, so its field `a` is (presumably)
    exposed as `--pre.a`, which collides with the manually added argument.
    The ArgumentError may surface when the dataclass is added or when the
    parser is first used, so both calls sit inside `raises`.
    """
    import dataclasses

    @dataclass
    class SomeClass:
        a: int = 1  # some docstring for attribute 'a'

    # The nested defaults use `default_factory`: a plain `SomeClass()` default
    # is an unhashable (mutable) default, which `dataclasses` rejects with a
    # ValueError on Python 3.11+.
    @dataclass
    class Parent:
        pre: SomeClass = dataclasses.field(default_factory=SomeClass)
        bla: SomeClass = dataclasses.field(default_factory=SomeClass)

    parser = ArgumentParser()
    parser.add_argument("--pre.a", default=123, type=int)
    with raises(argparse.ArgumentError):
        parser.add_arguments(Parent, dest="some_class")
        parser.parse_args("--pre.a 123 --pre.a 456".split())
def test_reproduce(self):
    """Reproduce a known bug: `--values 3 4` only consumes the first value.

    Ideally this would give `values == (3, 4)` with no extra args. Instead,
    `values` is just 3 and "4" is left over as an unknown argument.
    """
    parser = ArgumentParser()
    parser.add_arguments(MyConfig, dest="cfg")
    parser.parse_known_args([])  # result unused
    args, extra = parser.parse_known_args(["--values", "3", "4"])

    # What we actually get today (note: `(3)` would just be the int 3):
    assert args.cfg.values == 3
    assert extra == ["4"]
def test_optional_list(self, passed_arg: str, expected_value: Any):
    """An `Optional[List[str]]` field (default None) parses to `expected_value`.

    `passed_arg` is shell-split after "--foo", unless it is the DONT_PASS
    sentinel, in which case no arguments are given.
    """

    @dataclass
    class MyConfig:
        foo: Optional[List[str]] = None

    parser = ArgumentParser()
    parser.add_arguments(MyConfig, dest="config")
    argv = "" if passed_arg is DONT_PASS else shlex.split("--foo " + passed_arg)
    assert parser.parse_args(argv).config.foo == expected_value
def test_experiments():
    """A `subparsers` field selects the Experiment subclass by subcommand.

    "mnist" maps to `Mnist` and "mnist_continual" to `MnistContinual`.
    Parsing with no subcommand at all is expected to exit with an error.
    """
    from abc import ABC

    @dataclass
    class Experiment(ABC):
        dataset: str
        iid: bool = True

    @dataclass
    class Mnist(Experiment):
        dataset: str = "mnist"
        iid: bool = True

    @dataclass
    class MnistContinual(Experiment):
        dataset: str = "mnist"
        iid: bool = False

    @dataclass
    class Config:
        experiment: Experiment = subparsers({
            "mnist": Mnist,
            "mnist_continual": MnistContinual,
        })

    # Not named `field`, so the loop variable doesn't shadow a `field` helper.
    for config_field in dataclasses.fields(Config):
        assert simple_parsing.utils.is_subparser_field(config_field), config_field

    parser = ArgumentParser()
    parser.add_arguments(Config, "config")

    # Parse an explicit empty argv. A bare `parse_args()` reads sys.argv,
    # i.e. the test runner's own command line, so the outcome would depend on
    # how the tests were invoked.
    with raises(SystemExit):
        parser.parse_args("")

    args = parser.parse_args("mnist".split())
    experiment = args.config.experiment
    assert isinstance(experiment, Mnist)
    assert experiment.dataset == "mnist"
    assert experiment.iid is True

    args = parser.parse_args("mnist_continual".split())
    experiment = args.config.experiment
    assert isinstance(experiment, MnistContinual)
    assert experiment.dataset == "mnist"
    assert experiment.iid is False