def test_parsing_twice():
    """Re-using one parser for a second parse must not leak the first result."""

    @dataclass
    class Foo:
        a: int = 123

    parser = ArgumentParser()
    parser.add_arguments(Foo, dest="foo")

    first = parser.parse_args("")
    assert first.foo.a == 123, vars(first)

    second = parser.parse_args("--a 456".split())
    assert second.foo.a == 456, vars(second)
def test_vanilla_argparse_beheviour(self, options: Dict[str, Any], passed_arg: str, expected_value: Any):
    """A plain ``add_argument("--foo", **options)`` parses to ``expected_value``.

    ``passed_arg is DONT_PASS`` means the flag is left off the command line.
    """
    parser = ArgumentParser()
    parser.add_argument("--foo", **options)
    argv = "" if passed_arg is DONT_PASS else shlex.split("--foo " + passed_arg)
    namespace = parser.parse_args(argv)
    assert namespace.foo == expected_value
def test_store_false_action():
    """Passing ``--no-cache`` yields ``no_cache == False``; omitting it yields True."""
    parser = ArgumentParser(add_option_string_dash_variants=True)
    parser.add_arguments(Foo, "foo")
    for command_line, expected in (("--no-cache", False), ("", True)):
        foo: Foo = parser.parse_args(command_line.split()).foo
        assert foo.no_cache == expected
def test_optional_list(self, passed_arg: str, expected_value: Any):
    """An ``Optional[List[str]]`` field defaulting to None parses as expected.

    ``passed_arg is DONT_PASS`` means ``--foo`` is left off the command line.
    """

    @dataclass
    class MyConfig:
        foo: Optional[List[str]] = None

    parser = ArgumentParser()
    parser.add_arguments(MyConfig, dest="config")
    argv = "" if passed_arg is DONT_PASS else shlex.split("--foo " + passed_arg)
    assert parser.parse_args(argv).config.foo == expected_value
def test_optional_seed():
    """An Optional-typed field parses fine both when omitted and when given.

    (Reproduces https://github.com/lebrice/SimpleParsing/issues/14#issue-562538623)
    """
    parser = ArgumentParser()
    parser.add_arguments(Config, dest="config")

    defaults: Config = parser.parse_args("".split()).config
    assert defaults == Config()

    seeded: Config = parser.parse_args("--seed 123".split()).config
    assert seeded == Config(123)
def main_sl():
    """Apply the PnnMethod in a supervised-learning (SL) Setting.

    The Setting, the Config and the Method are all configured from the command
    line, and the results of applying the Method are printed and returned.
    """
    parser = ArgumentParser(description=__doc__, add_dest_to_option_strings=False)

    # Setting arguments.
    # TODO: PNN is coded for the DomainIncrementalSetting, where the action space
    # is the same for each task (a DomainIncrementalSetting variant was left
    # commented out here previously).
    parser.add_arguments(TaskIncrementalSLSetting, dest="setting")
    # Config and Method arguments.
    Config.add_argparse_args(parser, dest="config")
    PnnMethod.add_argparse_args(parser, dest="method")

    args = parser.parse_args()

    setting: TaskIncrementalSLSetting = TaskIncrementalSLSetting.from_argparse_args(
        args, dest="setting"
    )
    config: Config = Config.from_argparse_args(args, dest="config")
    method: PnnMethod = PnnMethod.from_argparse_args(args, dest="method")
    method.config = config

    results = setting.apply(method, config=config)
    print(results.summary())
    return results
def parse_setting_and_method_instances(
    setting: Union[Setting, Type[Setting]],
    method: Union[Method, Type[Method]],
    argv: Union[str, List[str], None] = None,
    strict_args: bool = False,
) -> Tuple[Setting, Method]:
    """Return a (Setting, Method) pair, parsing any that were given as types.

    Each of ``setting`` / ``method`` may be an instance, which is returned
    unchanged, or a subclass. A subclass has its command-line arguments added to
    a shared parser and is then built from the parsed args.

    Args:
        setting: A Setting instance or a Setting subclass.
        method: A Method instance or a Method subclass.
        argv: Command-line arguments, either as a list or as a single string
            (split with ``shlex.split``). ``None`` lets the parser use
            ``sys.argv``.
        strict_args: When True, unknown arguments are an argparse error. When
            False, they are logged as a warning and ignored.

    Returns:
        The ``(setting, method)`` instances.
    """
    import shlex

    # TODO: Should we raise an error if an argument appears both in the Setting
    # and the Method?
    parser = ArgumentParser(description=__doc__, add_dest_to_option_strings=False)

    setting_needs_parsing = not isinstance(setting, Setting)
    method_needs_parsing = not isinstance(method, Method)

    if setting_needs_parsing:
        assert issubclass(setting, Setting), f"Expected a Setting subclass, got {setting!r}"
        setting.add_argparse_args(parser)
    if method_needs_parsing:
        assert method is not None, "A Method instance or subclass is required"
        assert issubclass(method, Method), f"Expected a Method subclass, got {method!r}"
        method.add_argparse_args(parser)

    # argparse iterates whatever it is given, so a raw string would be parsed
    # one character at a time; split it like a shell would.
    if isinstance(argv, str):
        argv = shlex.split(argv)

    if strict_args:
        args = parser.parse_args(argv)
    else:
        args, unused_args = parser.parse_known_args(argv)
        if unused_args:
            logger.warning(UserWarning(f"Unused command-line args: {unused_args}"))

    if setting_needs_parsing:
        setting = setting.from_argparse_args(args)
    if method_needs_parsing:
        method = method.from_argparse_args(args)
    return setting, method
def test_experiments():
    """A ``subparsers`` field selects the Experiment subclass by sub-command."""
    from abc import ABC

    @dataclass
    class Experiment(ABC):
        dataset: str
        iid: bool = True

    @dataclass
    class Mnist(Experiment):
        dataset: str = "mnist"
        iid: bool = True

    @dataclass
    class MnistContinual(Experiment):
        dataset: str = "mnist"
        iid: bool = False

    @dataclass
    class Config:
        experiment: Experiment = subparsers({
            "mnist": Mnist,
            "mnist_continual": MnistContinual,
        })

    for config_field in dataclasses.fields(Config):
        assert simple_parsing.utils.is_subparser_field(config_field), config_field

    parser = ArgumentParser()
    parser.add_arguments(Config, "config")

    # Omitting the sub-command must make argparse exit. Pass an explicit empty
    # argv: a bare ``parse_args()`` reads ``sys.argv[1:]`` (pytest's own command
    # line), so the outcome would depend on how the test run was invoked.
    with raises(SystemExit):
        parser.parse_args([])

    args = parser.parse_args("mnist".split())
    experiment = args.config.experiment
    assert isinstance(experiment, Mnist)
    assert experiment.dataset == "mnist"
    assert experiment.iid is True

    args = parser.parse_args("mnist_continual".split())
    experiment = args.config.experiment
    assert isinstance(experiment, MnistContinual)
    assert experiment.dataset == "mnist"
    assert experiment.iid is False
def main():
    """Parse ``Settings`` from the command line, then call ``train`` or ``test``."""
    parser = ArgumentParser()
    parser.add_arguments(Settings, dest='opts')
    argcomplete.autocomplete(parser)
    settings = parser.parse_args().opts
    run = train if settings.train else test
    run(settings)
def test_passing_instance():
    """Field values of an instance passed to ``add_arguments`` act as defaults."""

    @dataclass
    class Foo:
        a: int = 123

    parser = ArgumentParser()
    parser.add_arguments(Foo(456), dest="foo")
    namespace = parser.parse_args("")
    assert namespace.foo.a == 456, vars(namespace)
def test_arg_and_dataclass_with_same_name(silent):
    """A plain ``--a`` option and a dataclass field ``a`` must be reported as a
    conflict (``argparse.ArgumentError``)."""

    @dataclass
    class SomeClass:
        a: int = 1  # some docstring for attribute 'a'

    parser = ArgumentParser()
    parser.add_argument("--a", default=123)
    # Both registration and parsing stay inside the block, so the error is
    # accepted from either call.
    with raises(argparse.ArgumentError):
        parser.add_arguments(SomeClass, dest="some_class")
        parser.parse_args("")
def test_issue_47_is_fixed(
    self,
    field_type: Type,
    passed_arg: Union[str, object],
    expected_value: Any,
):
    """A ``values`` field of type ``field_type`` parses to ``expected_value``.

    ``passed_arg is DONT_PASS`` means ``--values`` is left off the command line.
    """

    @dataclass
    class MyConfig:
        values: field_type

    parser = ArgumentParser()
    parser.add_arguments(MyConfig, dest="config")
    argv = [] if passed_arg is DONT_PASS else shlex.split("--values " + passed_arg)
    namespace = parser.parse_args(argv)
    assert namespace.config.values == expected_value
def test_issue_29():
    """A variadic ``Tuple[str, ...]`` field collects every value after its flag."""
    from simple_parsing import ArgumentParser

    @dataclass
    class MyCli:
        asdf: Tuple[str, ...]

    parser = ArgumentParser()
    parser.add_arguments(MyCli, dest="args")
    parsed = parser.parse_args("--asdf asdf fgfh".split())
    assert parsed.args == MyCli(asdf=("asdf", "fgfh"))
def test_arg_and_dataclass_with_same_name_after_prefixing(silent):
    """A manually added ``--pre.a`` must conflict with the prefixed option that
    ``Parent.pre.a`` produces, raising ``argparse.ArgumentError``."""
    from dataclasses import field

    @dataclass
    class SomeClass:
        a: int = 1  # some docstring for attribute 'a'

    @dataclass
    class Parent:
        # default_factory instead of ``= SomeClass()``: a (non-frozen) dataclass
        # instance is unhashable, and @dataclass rejects unhashable defaults with
        # ValueError on Python >= 3.11.
        pre: SomeClass = field(default_factory=SomeClass)
        bla: SomeClass = field(default_factory=SomeClass)

    parser = ArgumentParser()
    parser.add_argument("--pre.a", default=123, type=int)
    # Both registration and parsing stay inside the block, so the error is
    # accepted from either call.
    with raises(argparse.ArgumentError):
        parser.add_arguments(Parent, dest="some_class")
        parser.parse_args("--pre.a 123 --pre.a 456".split())
def test_works_fine_with_other_argparse_arguments(simple_attribute, silent):
    """Dataclass arguments coexist with a plain ``--x`` argparse argument."""
    some_type, passed_value, expected_value = simple_attribute

    @dataclass
    class SomeClass:
        a: some_type  # type: ignore
        """some docstring for attribute 'a'"""

    parser = ArgumentParser()
    parser.add_argument("--x", type=int)
    parser.add_arguments(SomeClass, dest="some_class")

    x = 123
    command_line = f"--x {x} --a {passed_value}"
    parsed = parser.parse_args(shlex.split(command_line))
    expected = argparse.Namespace(some_class=SomeClass(a=expected_value), x=x)
    assert parsed == expected
def to_simple_parsing_args(some_dataclass_type):
    """Parse an instance of ``some_dataclass_type`` from the command line.

    Intended to back a classmethod on a dataclass, so its fields become
    settable from the command line::

        @classmethod
        def get_args(cls):
            args: cls = to_simple_parsing_args(cls)
            return args

    Conflict resolution is disabled (``ConflictResolution.NONE``), and the parsed
    instance is read back from the ``cfg`` destination.
    """
    from simple_parsing import ArgumentParser, ConflictResolution

    cli = ArgumentParser(conflict_resolution=ConflictResolution.NONE)
    cli.add_arguments(some_dataclass_type, dest='cfg')
    namespace = cli.parse_args()
    return namespace.cfg
def main():
    """Parse ``Options`` from the command line and print the namespace and options."""
    parser = ArgumentParser()
    parser.add_arguments(Options, dest="options")
    # Equivalent to:
    # subparsers = parser.add_subparsers(title="config", required=False)
    # parser.set_defaults(config=AConfig())
    # a_parser = subparsers.add_parser("a", help="A help.")
    # a_parser.add_arguments(AConfig, dest="config")
    # b_parser = subparsers.add_parser("b", help="B help.")
    # b_parser.add_arguments(BConfig, dest="config")
    namespace = parser.parse_args()
    print(namespace)
    opts: Options = namespace.options
    print(opts)
def from_known_args(cls, argv: Union[str, List[str]] = None, reorder: bool = True,
                    strict: bool = False) -> Tuple[P, List[str]]:
    """Parse an instance of ``cls`` from ``argv`` and return it with leftovers.

    ``cls`` must provide the ``add_argparse_args`` / ``from_argparse_args``
    classmethods.

    Args:
        argv: Arguments as a list or a single string (split with
            ``shlex.split``). Defaults to ``sys.argv[1:]``.
        reorder: Forwarded as ``attempt_to_reorder`` to ``parse_known_args``
            when ``strict`` is False.
        strict: When True, unknown arguments are an argparse error. When False,
            they are returned (and logged at debug level).

    Returns:
        ``(instance, unused_args)``. The argv used is also stored on the
        instance's ``_argv`` attribute.
    """
    if argv is None:
        argv = sys.argv[1:]
    logger.debug(f"parsing an instance of class {cls} from argv {argv}")
    if isinstance(argv, str):
        argv = shlex.split(argv)

    parser = ArgumentParser(description=cls.__doc__, add_dest_to_option_strings=False)
    cls.add_argparse_args(parser)

    # TODO: Set temporarily on the class, so its accessible in the class constructor
    previous_cls_argv = cls._argv
    cls._argv = argv
    try:
        if strict:
            args = parser.parse_args(argv)
            unused_args = []
        else:
            args, unused_args = parser.parse_known_args(
                argv, attempt_to_reorder=reorder)
            if unused_args:
                logger.debug(
                    RuntimeWarning(
                        f"Unknown/unused args when parsing class {cls}: {unused_args}"
                    ))
        instance: P = cls.from_argparse_args(args)
        # Save the argv that were used to create the instance on its `_argv`
        # attribute.
        instance._argv = argv
    finally:
        # Restore the class attribute even when parsing or construction fails
        # (e.g. argparse raising SystemExit); otherwise the temporary argv leaks.
        cls._argv = previous_cls_argv
    return instance, unused_args
def baseline_demo_command_line():
    """Run the BaselineMethod on a Setting configured from the command line.

    Returns the results of ``setting.apply`` after printing their summary.
    """
    # ``description=`` is required: the first positional parameter of
    # ArgumentParser is ``prog``, so passing ``__doc__`` positionally made the
    # module docstring the program name shown in usage messages.
    parser = ArgumentParser(description=__doc__, add_dest_to_option_strings=False)

    # NOTE(review): both Settings below register under dest="setting"; confirm
    # which one is meant to be parsed (presumably only one should be enabled).
    # Supervised Learning Setting:
    parser.add_arguments(TaskIncrementalSetting, dest="setting")
    # Reinforcement Learning Setting:
    parser.add_arguments(TaskIncrementalRLSetting, dest="setting")

    parser.add_arguments(Config, dest="config")
    BaselineMethod.add_argparse_args(parser, dest="method")

    args = parser.parse_args()

    setting: Setting = args.setting
    config: Config = args.config
    method: BaselineMethod = BaselineMethod.from_argparse_args(args, dest="method")

    results = setting.apply(method, config=config)
    print(results.summary())
    return results
def main_rl():
    """Apply the PnnMethod in a reinforcement-learning (RL) Setting.

    The Config and Method come from the command line. The Setting is a fixed
    two-task cartpole schedule.
    """
    parser = ArgumentParser(description=__doc__, add_dest_to_option_strings=False)
    Config.add_argparse_args(parser, dest="config")
    PnnMethod.add_argparse_args(parser, dest="method")

    # Not tested with observe_state_directly=False: it runs, but whether it
    # converges has not been checked.
    schedule = {
        0: {"gravity": 10, "length": 0.3},
        1000: {"gravity": 10, "length": 0.5},
    }
    setting = TaskIncrementalRLSetting(
        dataset="cartpole",
        observe_state_directly=True,
        nb_tasks=2,
        train_task_schedule=schedule,
    )

    args = parser.parse_args()
    config: Config = Config.from_argparse_args(args, dest="config")
    method: PnnMethod = PnnMethod.from_argparse_args(args, dest="method")
    method.config = config

    results = setting.apply(method, config=config)
    print(results.summary())
    print(f"objective: {results.objective}")
    return results
def demo_command_line():
    """Run the same demo as above, customizing the Setting and Method from the
    command line, and print the BaselineMethod vs CustomMethod results.

    NOTE: Remember to uncomment the function call below to use this instead of
    demo_simple!
    """
    # Create the Setting and the Config from the command line, as in the other
    # examples.
    parser = ArgumentParser(description=__doc__)
    from sequoia.settings import TaskIncrementalRLSetting, TaskIncrementalSLSetting

    # Use TaskIncrementalSLSetting here instead for the supervised variant.
    parser.add_arguments(TaskIncrementalRLSetting, dest="setting")
    parser.add_arguments(Config, dest="config")
    # CustomMethod's arguments, including those of its simple regularization
    # auxiliary task.
    CustomMethod.add_argparse_args(parser, dest="method")

    args = parser.parse_args()
    setting: ClassIncrementalSetting = args.setting
    config: Config = args.config

    # Build each Method from the same parsed args and apply it, baseline first.
    all_results = []
    for method_type in (BaselineMethod, CustomMethod):
        method = method_type.from_argparse_args(args, dest="method")
        all_results.append(setting.apply(method, config=config))

    print("\n\nComparison: BaselineMethod vs CustomMethod:")
    for results in all_results:
        print(results.summary())
def test_cmd_false_doesnt_create_conflicts():
    """``A.batch_size`` is declared with ``cmd=False``. ``--batch_size`` must then
    reach only ``B`` and must not conflict, even with ConflictResolution.NONE."""

    @dataclass
    class A:
        batch_size: int = field(default=10, cmd=False)

    @dataclass
    class B:
        batch_size: int = 20

    parser = ArgumentParser(conflict_resolution=ConflictResolution.NONE)
    for dest, dataclass_type in (("a", A), ("b", B)):
        parser.add_arguments(dataclass_type, dest)

    parsed = parser.parse_args("--batch_size 32".split())
    assert parsed.a == A()
    assert parsed.b == B(batch_size=32)
def update_settings(opts: object, argv: List[str] = None):
    """
    Update given settings from command line arguments.
    Uses `argparse`, `argcomplete` and `simple_parsing` under the hood.

    Args:
        opts: A dataclass *instance* holding the current settings; its values
            become the parser defaults.
        argv: Arguments to parse; defaults to ``sys.argv[1:]``.

    Returns:
        The parsed settings instance.

    Raises:
        ValueError: If ``opts`` is not a dataclass instance (including when a
            dataclass type is passed instead of an instance).
    """
    if not is_dataclass(opts):
        raise ValueError('Cannot update args on non-dataclass class')
    # is_dataclass() is also True for dataclass *types*. Then ``type(opts)``
    # below would be ``type`` itself, so reject it up front with a clear error.
    if isinstance(opts, type):
        raise ValueError('Cannot update args on a dataclass type; pass an instance')

    # Use default system argv if not supplied.
    argv = sys.argv[1:] if argv is None else argv

    # Update from config file, if applicable.
    parser_parents = []
    if isinstance(opts, Serializable):
        opts, argv, parser_parents = _update_settings_from_file(opts, argv)

    parser = ArgumentParser(parents=parser_parents)
    parser.add_arguments(type(opts), dest='opts', default=opts)
    if use_argcomplete:
        argcomplete.autocomplete(parser)
    args = parser.parse_args(argv)
    return args.opts
def demo_command_line():
    """Run this quick demo from the command-line.

    Applies a DemoMethod to a DomainIncrementalSLSetting, all configured from
    the command line, and prints the summary and objective.
    """
    parser = ArgumentParser(description=__doc__)
    # The Method registers its own command-line arguments.
    DemoMethod.add_argparse_args(parser)
    # The Setting and the Config (options like log_dir, debug, etc., which belong
    # to neither the Setting nor the Method) are added via simple-parsing.
    for dataclass_type, dest in ((DomainIncrementalSLSetting, "setting"), (Config, "config")):
        parser.add_arguments(dataclass_type, dest)

    args = parser.parse_args()

    method: DemoMethod = DemoMethod.from_argparse_args(args)
    setting: DomainIncrementalSLSetting = args.setting
    config: Config = args.config

    results: Results = setting.apply(method, config=config)
    print(results.summary())
    print(f"objective: {results.objective}")
def demo():
    """Run the EwcMethod on a simple setting, just to check that it works fine.

    The Method and Config come from the command line. The Setting is a two-task
    cartpole schedule used for both training and testing.
    """
    # Add the argument groups directly to one parser.
    parser = ArgumentParser(description=__doc__)
    EwcMethod.add_argparse_args(parser, dest="method")
    parser.add_arguments(Config, "config")
    args = parser.parse_args()

    method = EwcMethod.from_argparse_args(args, dest="method")
    config: Config = args.config

    schedule = {
        0: {"gravity": 10, "length": 0.2},
        1000: {"gravity": 100, "length": 1.2},
    }
    setting = TaskIncrementalRLSetting(
        dataset="cartpole",
        train_task_schedule=schedule,
        test_task_schedule=schedule,
        observe_state_directly=True,
    )
    # A supervised alternative, e.g. ClassIncrementalSetting(dataset="mnist",
    # nb_tasks=5), could be substituted here.
    results = setting.apply(method, config=config)
    print(results.summary())
def get_args():
    """Parse the command line in two passes.

    The first pass (``parse_known_args``) reads only the ``--runner``,
    ``--environment`` and ``--agent`` selectors. The chosen classes then add
    their own arguments, and the second pass parses the full command line.
    """
    parser = ArgumentParser()
    parser.add_argument('--runner', default='test', choices=runner_registry.keys(),
                        help="Specify whether to train or test an agent")
    parser.add_argument('--environment', default="car", choices=environment_registry.keys(),
                        help="Environment to train/test agent on")
    parser.add_argument('--agent', default='scp', choices=agent_registry.keys(),
                        help='Agent to train/test')

    selectors, _ = parser.parse_known_args()
    runner_class = runner_registry[selectors.runner]
    agent_class = agent_registry[selectors.agent]
    environment_class = environment_registry[selectors.environment]

    # Each component returns the parser it extended.
    for component in (agent_class, environment_class, runner_class):
        parser = component.add_argparse_args(parser)
    return parser.parse_args()
the name of the class in the --help group for this set of parameters. """ attribute1: float = 1.0 """docstring below, When used, this always shows up in the --help text for this attribute""" # Comment above only: this shows up in the help text, since there is no docstring below. attribute2: float = 1.0 attribute3: float = 1.0 # inline comment only (this shows up in the help text, since none of the two other options are present.) # comment above 42 attribute4: float = 1.0 # inline comment """docstring below (this appears in --help)""" # comment above (this appears in --help) 46 attribute5: float = 1.0 # inline comment attribute6: float = 1.0 # inline comment (this appears in --help) attribute7: float = 1.0 # inline comment """docstring below (this appears in --help)""" parser.add_arguments(DocStringsExample, "example") args = parser.parse_args() ex = args.example print(ex) expected = """ DocStringsExample(attribute1=1.0, attribute2=1.0, attribute3=1.0, attribute4=1.0, attribute5=1.0, attribute6=1.0, attribute7=1.0) """
def parse(cls, args: str = ""):
    """Parse ``args`` (whitespace-split) into an instance of the dataclass ``cls``.

    Helper that removes some boilerplate code from the examples.
    """
    parser = ArgumentParser()
    parser.add_arguments(cls, "example")
    return parser.parse_args(args.split()).example
@dataclass
class CNNStack:
    name: str = "stack"
    num_layers: int = 3
    kernel_sizes: Tuple[int, int, int] = (7, 5, 5)
    num_filters: List[int] = field(default_factory=[32, 64, 64].copy)


parser = ArgumentParser(conflict_resolution=ConflictResolution.ALWAYS_MERGE)

# Register the same dataclass several times. With ALWAYS_MERGE the options are
# merged, and each destination gets an explicit CNNStack() default.
num_stacks = 3
for i in range(num_stacks):
    parser.add_arguments(CNNStack, dest=f"stack_{i}", default=CNNStack())

args = parser.parse_args()
stack_0, stack_1, stack_2 = (getattr(args, f"stack_{i}") for i in range(num_stacks))

# BUG: When the list length and the number of instances to parse is the same,
# AND there is no default value passed to `add_arguments`, it gets parsed as
# multiple lists each with only one element, rather than duplicating the field's
# default value correctly.
print(stack_0, stack_1, stack_2, sep="\n")
expected = """\
CNNStack(name='stack', num_layers=3, kernel_sizes=(7, 5, 5), num_filters=[32, 64, 64])
CNNStack(name='stack', num_layers=3, kernel_sizes=(7, 5, 5), num_filters=[32, 64, 64])
CNNStack(name='stack', num_layers=3, kernel_sizes=(7, 5, 5), num_filters=[32, 64, 64])
"""
def wrapper():
    """Parse ``cls`` (defaults taken from ``instance``) and call ``main`` with it.

    Arguments come from ``argv``, or from ``sys.argv[1:]`` when ``argv`` is None.
    """
    parser = ArgumentParser()
    parser.add_arguments(cls, dest='opts', default=instance)
    argcomplete.autocomplete(parser)
    command_line = sys.argv[1:] if argv is None else argv
    parsed = parser.parse_args(command_line)
    return main(parsed.opts)