コード例 #1
0
def test_parsing_twice():
    """Check that a single parser can be used for two consecutive parses."""

    @dataclass
    class Foo:
        a: int = 123

    parser = ArgumentParser()
    parser.add_arguments(Foo, dest="foo")

    # First parse: empty command line, the field keeps its default.
    first = parser.parse_args("")
    assert first.foo.a == 123, vars(first)

    # Second parse on the same parser: the value given on the command line
    # is the one that ends up on the dataclass.
    second = parser.parse_args("--a 456".split())
    assert second.foo.a == 456, vars(second)
コード例 #2
0
ファイル: test_tuples.py プロジェクト: lebrice/SimpleParsing
    def test_vanilla_argparse_beheviour(self, options: Dict[str, Any],
                                        passed_arg: str, expected_value: Any):
        """Check the value produced for ``--foo`` by a plain ``add_argument``.

        ``options`` is forwarded as keyword arguments to ``add_argument``.
        ``passed_arg`` is either the ``DONT_PASS`` sentinel (an empty command
        line is parsed) or the text placed after ``--foo``.
        """
        parser = ArgumentParser()
        parser.add_argument("--foo", **options)

        argv = "" if passed_arg is DONT_PASS else shlex.split("--foo " + passed_arg)
        parsed = parser.parse_args(argv)
        assert parsed.foo == expected_value
コード例 #3
0
def test_store_false_action():
    """Check the ``no_cache`` field of ``Foo`` with and without ``--no-cache``.

    Passing ``--no-cache`` must yield ``False``; an empty command line must
    yield ``True``.
    """
    parser = ArgumentParser(add_option_string_dash_variants=True)
    parser.add_arguments(Foo, "foo")

    # Identity checks (rather than ``== False`` / ``== True``) make sure the
    # parsed value is a real bool and keep linters quiet (PEP 8 / E712).
    with_flag: Foo = parser.parse_args("--no-cache".split()).foo
    assert with_flag.no_cache is False

    without_flag: Foo = parser.parse_args("".split()).foo
    assert without_flag.no_cache is True
コード例 #4
0
ファイル: test_tuples.py プロジェクト: lebrice/SimpleParsing
    def test_optional_list(self, passed_arg: str, expected_value: Any):
        """Check parsing of an ``Optional[List[str]]`` field defaulting to None.

        ``passed_arg`` is either the ``DONT_PASS`` sentinel (an empty command
        line is parsed) or the text placed after ``--foo``.
        """

        @dataclass
        class MyConfig:
            foo: Optional[List[str]] = None

        parser = ArgumentParser()
        parser.add_arguments(MyConfig, dest="config")

        argv = "" if passed_arg is DONT_PASS else shlex.split("--foo " + passed_arg)
        parsed = parser.parse_args(argv)
        assert parsed.config.foo == expected_value
コード例 #5
0
def test_optional_seed():
    """Check that a field marked as Optional is parsed correctly.

    Regression test for
    https://github.com/lebrice/SimpleParsing/issues/14#issue-562538623
    """
    parser = ArgumentParser()
    parser.add_arguments(Config, dest="config")

    # Nothing passed: the parsed object equals a default-constructed Config.
    unseeded: Config = parser.parse_args("".split()).config
    assert unseeded == Config()

    # Seed passed: the parsed object equals Config(123).
    seeded: Config = parser.parse_args("--seed 123".split()).config
    assert seeded == Config(123)
コード例 #6
0
def main_sl():
    """Apply the PnnMethod to a task-incremental SL setting built from the CLI.

    Returns:
        Whatever ``setting.apply`` returns; its ``summary()`` is printed first.
    """
    parser = ArgumentParser(
        description=__doc__,
        add_dest_to_option_strings=False,
    )

    # Register the options of the Setting, the Config and the Method.
    # TODO: PNN is coded for the DomainIncrementalSetting, where the action
    # space is the same for each task.
    parser.add_arguments(TaskIncrementalSLSetting, dest="setting")
    Config.add_argparse_args(parser, dest="config")
    PnnMethod.add_argparse_args(parser, dest="method")

    namespace = parser.parse_args()

    # Rebuild every object from the parsed namespace.
    setting: TaskIncrementalSLSetting = TaskIncrementalSLSetting.from_argparse_args(
        namespace, dest="setting"
    )
    config: Config = Config.from_argparse_args(namespace, dest="config")
    method: PnnMethod = PnnMethod.from_argparse_args(namespace, dest="method")
    method.config = config

    results = setting.apply(method, config=config)
    print(results.summary())
    return results
コード例 #7
0
def parse_setting_and_method_instances(
    setting: Union[Setting, Type[Setting]],
    method: Union[Method, Type[Method]],
    argv: Union[str, List[str]] = None,
    strict_args: bool = False,
) -> Tuple[Setting, Method]:
    """Return a ``(setting, method)`` pair of instances.

    Any of the two that is given as a class (rather than an instance) is
    instantiated from the command line.

    Args:
        setting: A ``Setting`` instance, returned unchanged, or a ``Setting``
            subclass whose arguments are added to the parser.
        method: A ``Method`` instance, returned unchanged, or a ``Method``
            subclass whose arguments are added to the parser.
        argv: Command line to parse, either as a single string or as a list
            of tokens. ``None`` is handed to the parser as-is.
        strict_args: When true, ``parse_args`` is used. Otherwise
            ``parse_known_args`` is used and leftover arguments are only
            logged as a warning.

    Returns:
        The setting and method instances.
    """
    # Local import so the block does not depend on a module-level import.
    import shlex

    # argparse consumes a sequence of tokens; a bare string would be walked
    # character by character, so tokenize it first (same approach as the
    # `from_known_args` classmethod).
    if isinstance(argv, str):
        argv = shlex.split(argv)

    # TODO: Should we raise an error if an argument appears both in the Setting
    # and the Method?
    parser = ArgumentParser(description=__doc__, add_dest_to_option_strings=False)

    setting_is_class = not isinstance(setting, Setting)
    method_is_class = not isinstance(method, Method)

    if setting_is_class:
        assert issubclass(setting, Setting)
        setting.add_argparse_args(parser)
    if method_is_class:
        assert method is not None
        assert issubclass(method, Method)
        method.add_argparse_args(parser)

    if strict_args:
        args = parser.parse_args(argv)
    else:
        args, unused_args = parser.parse_known_args(argv)
        if unused_args:
            logger.warning(UserWarning(f"Unused command-line args: {unused_args}"))

    if setting_is_class:
        setting = setting.from_argparse_args(args)
    if method_is_class:
        method = method.from_argparse_args(args)

    return setting, method
コード例 #8
0
def test_experiments():
    """Check that a ``subparsers`` field selects the right Experiment subclass.

    Each command name (``mnist``, ``mnist_continual``) must produce an instance
    of the matching dataclass with that dataclass' default values.
    """
    from abc import ABC

    @dataclass
    class Experiment(ABC):
        dataset: str
        iid: bool = True

    @dataclass
    class Mnist(Experiment):
        dataset: str = "mnist"
        iid: bool = True

    @dataclass
    class MnistContinual(Experiment):
        dataset: str = "mnist"
        iid: bool = False

    @dataclass
    class Config:
        experiment: Experiment = subparsers({
            "mnist": Mnist,
            "mnist_continual": MnistContinual,
        })

    for field in dataclasses.fields(Config):
        assert simple_parsing.utils.is_subparser_field(field), field

    parser = ArgumentParser()
    parser.add_arguments(Config, "config")

    # NOTE: `parse_args()` without an argument reads `sys.argv`; the parse is
    # expected to exit here. The result is never used, so it is not bound.
    with raises(SystemExit):
        parser.parse_args()

    experiment = parser.parse_args("mnist".split()).config.experiment
    assert isinstance(experiment, Mnist)
    assert experiment.dataset == "mnist"
    # Identity checks instead of `== True` / `== False` (PEP 8 / E712).
    assert experiment.iid is True

    experiment = parser.parse_args("mnist_continual".split()).config.experiment
    assert isinstance(experiment, MnistContinual)
    assert experiment.dataset == "mnist"
    assert experiment.iid is False
コード例 #9
0
ファイル: ppo_v2.py プロジェクト: yycho0108/PPO
def main():
    """Parse ``Settings`` from the command line, then train or test with them."""
    parser = ArgumentParser()
    parser.add_arguments(Settings, dest='opts')
    argcomplete.autocomplete(parser)

    opts = parser.parse_args().opts
    # `opts.train` picks which of the two entry points receives the options.
    entry_point = train if opts.train else test
    entry_point(opts)
コード例 #10
0
def test_passing_instance():
    """Check that ``add_arguments`` accepts a dataclass instance.

    With an empty command line, the parsed object must carry the value held
    by the instance (456), not the class default (123).
    """

    @dataclass
    class Foo:
        a: int = 123

    parser = ArgumentParser()
    parser.add_arguments(Foo(456), dest="foo")

    parsed = parser.parse_args("")
    assert parsed.foo.a == 456, vars(parsed)
コード例 #11
0
def test_arg_and_dataclass_with_same_name(silent):
    """Check that a dataclass field clashing with an existing ``--a`` errors.

    An ``argparse.ArgumentError`` is expected somewhere between registering
    the dataclass and parsing an empty command line.
    """

    @dataclass
    class SomeClass:
        a: int = 1  # some docstring for attribute 'a'

    parser = ArgumentParser()
    parser.add_argument("--a", default=123)

    with raises(argparse.ArgumentError):
        parser.add_arguments(SomeClass, dest="some_class")
        parser.parse_args("")
コード例 #12
0
ファイル: test_tuples.py プロジェクト: lebrice/SimpleParsing
    def test_issue_47_is_fixed(
        self,
        field_type: Type,
        passed_arg: Union[str, object],
        expected_value: Any,
    ):
        """Check parsing of a required ``values`` field of type ``field_type``.

        ``passed_arg`` is either the ``DONT_PASS`` sentinel (an empty argument
        list is parsed) or the text placed after ``--values``.
        """

        @dataclass
        class MyConfig:
            values: field_type

        parser = ArgumentParser()
        parser.add_arguments(MyConfig, dest="config")

        argv = [] if passed_arg is DONT_PASS else shlex.split("--values " + passed_arg)
        parsed = parser.parse_args(argv)
        assert parsed.config.values == expected_value
コード例 #13
0
ファイル: test_tuples.py プロジェクト: lebrice/SimpleParsing
def test_issue_29():
    """Check that a ``Tuple[str, ...]`` field collects every value passed."""
    from simple_parsing import ArgumentParser

    @dataclass
    class MyCli:
        asdf: Tuple[str, ...]

    parser = ArgumentParser()
    parser.add_arguments(MyCli, dest="args")

    namespace = parser.parse_args("--asdf asdf fgfh".split())
    expected = MyCli(asdf=("asdf", "fgfh"))
    assert namespace.args == expected
コード例 #14
0
def test_arg_and_dataclass_with_same_name_after_prefixing(silent):
    """Check that a pre-existing ``--pre.a`` option clashes with a nested field.

    ``Parent`` holds two ``SomeClass`` fields, so their ``a`` attributes need
    prefixing; an ``argparse.ArgumentError`` is expected because ``--pre.a``
    was already registered on the parser.
    """
    # Local import under an unambiguous name, so the block does not depend on
    # which `field` (if any) the surrounding module imports.
    from dataclasses import field as dataclass_field

    @dataclass
    class SomeClass:
        a: int = 1  # some docstring for attribute 'a'

    @dataclass
    class Parent:
        # A dataclass instance is unhashable, and since Python 3.11 any
        # unhashable default is rejected with ValueError at class creation.
        # A default_factory gives each Parent its own SomeClass instead.
        pre: SomeClass = dataclass_field(default_factory=SomeClass)
        bla: SomeClass = dataclass_field(default_factory=SomeClass)

    parser = ArgumentParser()
    parser.add_argument("--pre.a", default=123, type=int)
    with raises(argparse.ArgumentError):
        parser.add_arguments(Parent, dest="some_class")
        parser.parse_args("--pre.a 123 --pre.a 456".split())
コード例 #15
0
def test_works_fine_with_other_argparse_arguments(simple_attribute, silent):
    """Check that dataclass arguments coexist with a regular ``--x`` option.

    ``simple_attribute`` unpacks to ``(type, value passed on the command line,
    value expected after parsing)``.
    """
    some_type, passed_value, expected_value = simple_attribute

    @dataclass
    class SomeClass:
        a: some_type  # type: ignore
        """some docstring for attribute 'a'"""

    parser = ArgumentParser()
    parser.add_argument("--x", type=int)
    parser.add_arguments(SomeClass, dest="some_class")

    x = 123
    argv = shlex.split(f"--x {x} --a {passed_value}")
    parsed = parser.parse_args(argv)

    expected = argparse.Namespace(some_class=SomeClass(a=expected_value), x=x)
    assert parsed == expected
コード例 #16
0
def to_simple_parsing_args(some_dataclass_type):
    """Build an instance of ``some_dataclass_type`` from the command line.

    Intended to back a classmethod on a dataclass so that its fields become
    command-line arguments, e.g.:

        @classmethod
        def get_args(cls):
            args: cls = to_simple_parsing_args(cls)
            return args

    The parser is created with ``ConflictResolution.NONE`` and the dataclass
    is registered under the ``cfg`` destination.
    """
    from simple_parsing import ArgumentParser, ConflictResolution

    parser = ArgumentParser(conflict_resolution=ConflictResolution.NONE)
    parser.add_arguments(some_dataclass_type, dest='cfg')

    namespace = parser.parse_args()
    parsed: some_dataclass_type = namespace.cfg
    return parsed
コード例 #17
0
def main():
    """Parse ``Options`` from the command line and print the result."""
    parser = ArgumentParser()
    parser.add_arguments(Options, dest="options")

    # According to the original author, registering `Options` is equivalent
    # to wiring the sub-commands by hand:
    #   subparsers = parser.add_subparsers(title="config", required=False)
    #   parser.set_defaults(config=AConfig())
    #   a_parser = subparsers.add_parser("a", help="A help.")
    #   a_parser.add_arguments(AConfig, dest="config")
    #   b_parser = subparsers.add_parser("b", help="B help.")
    #   b_parser.add_arguments(BConfig, dest="config")

    namespace = parser.parse_args()
    print(namespace)

    options: Options = namespace.options
    print(options)
コード例 #18
0
    def from_known_args(cls,
                        argv: Union[str, List[str]] = None,
                        reorder: bool = True,
                        strict: bool = False) -> Tuple[P, List[str]]:
        """Create an instance of ``cls`` from command-line arguments.

        Args:
            argv: Command line to parse, as one string (tokenized with
                ``shlex.split``) or a list of tokens. Defaults to
                ``sys.argv[1:]`` when ``None``.
            reorder: Forwarded to ``parse_known_args`` as
                ``attempt_to_reorder``. Only used when ``strict`` is false.
            strict: When true, ``parse_args`` is used and the returned list
                of unused arguments is empty. Otherwise unknown arguments are
                collected and logged at debug level.

        Returns:
            A tuple ``(instance, unused_args)``. The instance has its
            ``_argv`` attribute set to the argument list that was parsed.
        """
        if argv is None:
            argv = sys.argv[1:]
        logger.debug(f"parsing an instance of class {cls} from argv {argv}")
        if isinstance(argv, str):
            argv = shlex.split(argv)

        parser = ArgumentParser(description=cls.__doc__,
                                add_dest_to_option_strings=False)
        cls.add_argparse_args(parser)

        # TODO: Set temporarily on the class, so it's accessible in the class
        # constructor.
        previous_cls_argv = cls._argv
        cls._argv = argv

        instance: P
        try:
            if strict:
                args = parser.parse_args(argv)
                unused_args = []
            else:
                args, unused_args = parser.parse_known_args(
                    argv, attempt_to_reorder=reorder)
                if unused_args:
                    logger.debug(
                        RuntimeWarning(
                            f"Unknown/unused args when parsing class {cls}: {unused_args}"
                        ))
            instance = cls.from_argparse_args(args)
            # Save the argv that were used to create the instance on its
            # `_argv` attribute.
            instance._argv = argv
        finally:
            # Always put the class attribute back, including when parsing
            # exits or the constructor raises, so the temporary override
            # never outlives this call.
            cls._argv = previous_cls_argv
        return instance, unused_args
コード例 #19
0
def baseline_demo_command_line():
    """Run the BaselineMethod on a Setting, both configured from the CLI.

    Returns:
        Whatever ``setting.apply`` returns; its ``summary()`` is printed first.
    """
    # The first positional parameter of argparse's ArgumentParser is `prog`,
    # so the module docstring has to be passed as `description` explicitly.
    parser = ArgumentParser(description=__doc__, add_dest_to_option_strings=False)

    # NOTE(review): both settings below are registered under the same
    # destination ("setting"). Presumably only one of the two is meant to be
    # active at a time - confirm which one and comment out the other.
    # Supervised Learning Setting:
    parser.add_arguments(TaskIncrementalSetting, dest="setting")
    # Reinforcement Learning Setting:
    parser.add_arguments(TaskIncrementalRLSetting, dest="setting")

    parser.add_arguments(Config, dest="config")
    BaselineMethod.add_argparse_args(parser, dest="method")

    args = parser.parse_args()

    setting: Setting = args.setting
    config: Config = args.config
    method: BaselineMethod = BaselineMethod.from_argparse_args(args, dest="method")

    results = setting.apply(method, config=config)
    print(results.summary())
    return results
コード例 #20
0
ファイル: pnn_method.py プロジェクト: ryanlindeborg/Sequoia
def main_rl():
    """Apply the PnnMethod to a hard-coded cartpole RL setting.

    Only the Config and the Method are read from the command line.

    Returns:
        Whatever ``setting.apply`` returns; its summary and objective are
        printed first.
    """
    parser = ArgumentParser(
        description=__doc__,
        add_dest_to_option_strings=False,
    )
    Config.add_argparse_args(parser, dest="config")
    PnnMethod.add_argparse_args(parser, dest="method")

    # Original author's note: not tested with observe_state_directly=False;
    # it runs, but whether it converges is unknown.
    schedule = {
        0: {"gravity": 10, "length": 0.3},
        1000: {"gravity": 10, "length": 0.5},
    }
    setting = TaskIncrementalRLSetting(
        dataset="cartpole",
        observe_state_directly=True,
        nb_tasks=2,
        train_task_schedule=schedule,
    )

    namespace = parser.parse_args()

    # Build the Config and the Method from the parsed arguments.
    config: Config = Config.from_argparse_args(namespace, dest="config")
    method: PnnMethod = PnnMethod.from_argparse_args(namespace, dest="method")
    method.config = config

    # Apply the method to the setting and report.
    results = setting.apply(method, config=config)
    print(results.summary())
    print(f"objective: {results.objective}")
    return results
コード例 #21
0
def demo_command_line():
    """Compare BaselineMethod and CustomMethod on a CLI-configured Setting.

    The Setting, the Config and the method arguments all come from the
    command line; both methods are applied to the same setting and their
    result summaries are printed.
    """
    parser = ArgumentParser(description=__doc__)

    # Settings are imported here, right where they are registered.
    from sequoia.settings import TaskIncrementalRLSetting, TaskIncrementalSLSetting

    # To use the SL setting instead, register TaskIncrementalSLSetting under
    # the same destination in place of the RL one.
    parser.add_arguments(TaskIncrementalRLSetting, dest="setting")
    parser.add_arguments(Config, dest="config")

    # Command-line arguments of the CustomMethod.
    CustomMethod.add_argparse_args(parser, dest="method")

    namespace = parser.parse_args()

    setting: ClassIncrementalSetting = namespace.setting
    config: Config = namespace.config

    # Baseline first ...
    baseline = BaselineMethod.from_argparse_args(namespace, dest="method")
    baseline_results = setting.apply(baseline, config=config)

    # ... then the custom method, on the same setting and config.
    custom = CustomMethod.from_argparse_args(namespace, dest="method")
    custom_results = setting.apply(custom, config=config)

    print(f"\n\nComparison: BaselineMethod vs CustomMethod:")
    print(baseline_results.summary())
    print(custom_results.summary())
0
ファイル: test_fields.py プロジェクト: lebrice/SimpleParsing
def test_cmd_false_doesnt_create_conflicts():
    """Check that a ``cmd=False`` field does not clash with a same-named field.

    ``A.batch_size`` is declared with ``cmd=False`` and ``B.batch_size`` is a
    regular field; ``--batch_size 32`` must only change ``B``.
    """

    @dataclass
    class A:
        batch_size: int = field(default=10, cmd=False)

    @dataclass
    class B:
        batch_size: int = 20

    parser = ArgumentParser(conflict_resolution=ConflictResolution.NONE)
    parser.add_arguments(A, "a")
    parser.add_arguments(B, "b")

    namespace = parser.parse_args("--batch_size 32".split())
    parsed_a: A = namespace.a
    parsed_b: B = namespace.b

    assert parsed_a == A()
    assert parsed_b == B(batch_size=32)
コード例 #23
0
def update_settings(opts: dataclass, argv: List[str] = None):
    """Return ``opts`` updated with values given on the command line.

    Args:
        opts: Dataclass object providing the default values.
        argv: Arguments to parse; ``sys.argv[1:]`` when ``None``.

    Returns:
        The ``opts`` attribute of the parsed namespace.

    Raises:
        ValueError: If ``opts`` does not pass ``is_dataclass``.
    """
    if not is_dataclass(opts):
        raise ValueError('Cannot update args on non-dataclass class')

    if argv is None:
        argv = sys.argv[1:]

    # Serializable settings go through the config-file step first, which
    # returns updated opts/argv and the parent parsers to attach.
    parents = []
    if isinstance(opts, Serializable):
        opts, argv, parents = _update_settings_from_file(opts, argv)

    parser = ArgumentParser(parents=parents)
    parser.add_arguments(type(opts), dest='opts', default=opts)
    if use_argcomplete:
        argcomplete.autocomplete(parser)

    return parser.parse_args(argv).opts
コード例 #24
0
def demo_command_line():
    """Run the DemoMethod on a DomainIncrementalSLSetting set up from the CLI."""
    parser = ArgumentParser(description=__doc__)

    # The Method registers its own arguments; the Setting and the Config are
    # registered through simple-parsing. Per the original author, Config holds
    # options such as log_dir or debug that belong to neither the Setting nor
    # the Method.
    DemoMethod.add_argparse_args(parser)
    parser.add_arguments(DomainIncrementalSLSetting, "setting")
    parser.add_arguments(Config, "config")

    namespace = parser.parse_args()

    # Method built from the namespace; Setting and Config read straight off it.
    method: DemoMethod = DemoMethod.from_argparse_args(namespace)
    setting: DomainIncrementalSLSetting = namespace.setting
    config: Config = namespace.config

    # Apply the method to the setting and report.
    results: Results = setting.apply(method, config=config)
    print(results.summary())
    print(f"objective: {results.objective}")
コード例 #25
0
def demo():
    """Run the EwcMethod on a small cartpole setting as a smoke check.

    The Method and the Config come from the command line; the setting is
    hard-coded, with the same task schedule for training and testing.
    """
    parser = ArgumentParser(description=__doc__)
    EwcMethod.add_argparse_args(parser, dest="method")
    parser.add_arguments(Config, "config")

    namespace = parser.parse_args()

    method = EwcMethod.from_argparse_args(namespace, dest="method")
    config: Config = namespace.config

    # Keys are the step at which each set of environment parameters applies
    # (presumably - confirm against TaskIncrementalRLSetting).
    task_schedule = {
        0: {"gravity": 10, "length": 0.2},
        1000: {"gravity": 100, "length": 1.2},
    }
    setting = TaskIncrementalRLSetting(
        dataset="cartpole",
        train_task_schedule=task_schedule,
        test_task_schedule=task_schedule,
        observe_state_directly=True,
    )

    results = setting.apply(method, config=config)
    print(results.summary())
コード例 #26
0
ファイル: main.py プロジェクト: dtch1997/neural-car
def get_args():
    """Parse the command line in two passes and return the final namespace.

    The first pass only reads ``--runner``, ``--environment`` and ``--agent``;
    the classes they select then add their own arguments to the parser before
    the full parse.
    """
    parser = ArgumentParser()

    # (flag, default, registry providing the choices, help text)
    selectors = (
        ('--runner', 'test', runner_registry,
         "Specify whether to train or test an agent"),
        ('--environment', "car", environment_registry,
         "Environment to train/test agent on"),
        ('--agent', 'scp', agent_registry,
         'Agent to train/test'),
    )
    for flag, default, registry, help_text in selectors:
        parser.add_argument(flag,
                            default=default,
                            choices=registry.keys(),
                            help=help_text)

    # First pass: tolerate the not-yet-registered, class-specific options.
    selected, _ = parser.parse_known_args()
    runner_class = runner_registry[selected.runner]
    agent_class = agent_registry[selected.agent]
    environment_class = environment_registry[selected.environment]

    # Each selected class extends the parser and hands it back.
    for component in (agent_class, environment_class, runner_class):
        parser = component.add_argparse_args(parser)
    return parser.parse_args()
コード例 #27
0
    the name of the class in the --help group for this set of parameters.    
    """

    attribute1: float = 1.0
    """docstring below, When used, this always shows up in the --help text for this attribute"""

    # Comment above only: this shows up in the help text, since there is no docstring below.
    attribute2: float = 1.0

    attribute3: float = 1.0  # inline comment only (this shows up in the help text, since none of the two other options are present.)

    # comment above 42
    attribute4: float = 1.0  # inline comment
    """docstring below (this appears in --help)"""

    # comment above (this appears in --help) 46
    attribute5: float = 1.0  # inline comment

    attribute6: float = 1.0  # inline comment (this appears in --help)

    attribute7: float = 1.0  # inline comment
    """docstring below (this appears in --help)"""


# Register the example dataclass under the "example" destination and parse
# the process' command line.
parser.add_arguments(DocStringsExample, "example")
args = parser.parse_args()
ex = args.example
print(ex)
# Text of the line printed above when every attribute keeps its default of
# 1.0 (presumably compared against the script's output by the example
# runner - confirm).
expected = """
DocStringsExample(attribute1=1.0, attribute2=1.0, attribute3=1.0, attribute4=1.0, attribute5=1.0, attribute6=1.0, attribute7=1.0)
"""
コード例 #28
0
def parse(cls, args: str = ""):
    """Parse ``args`` into an instance of the dataclass ``cls``.

    Small helper that keeps parser boilerplate out of the examples: ``args``
    is split on whitespace and the dataclass is registered under the
    ``example`` destination.
    """
    parser = ArgumentParser()
    parser.add_arguments(cls, "example")
    namespace = parser.parse_args(args.split())
    return namespace.example
コード例 #29
0
# Options describing one stack of convolutional layers. Every instance gets
# its own `num_filters` list through the default factory.
@dataclass
class CNNStack:
    name: str = "stack"
    num_layers: int = 3
    kernel_sizes: Tuple[int, int, int] = (7, 5, 5)
    num_filters: List[int] = field(default_factory=lambda: [32, 64, 64])


# The same dataclass is registered several times below, so the parser is
# created with the ALWAYS_MERGE conflict-resolution mode.
parser = ArgumentParser(conflict_resolution=ConflictResolution.ALWAYS_MERGE)

# Register one CNNStack per destination: stack_0, stack_1, stack_2. Each call
# gets its own default instance.
num_stacks = 3
for i in range(num_stacks):
    parser.add_arguments(CNNStack, dest=f"stack_{i}", default=CNNStack())

args = parser.parse_args()
stack_0 = args.stack_0
stack_1 = args.stack_1
stack_2 = args.stack_2

# BUG: When the list length and the number of instances to parse are the same,
# AND no default value is passed to `add_arguments`, the field gets parsed as
# multiple lists with only one element each, rather than duplicating the
# field's default value correctly.

print(stack_0, stack_1, stack_2, sep="\n")
# Text of the three lines printed above when every stack keeps its defaults.
expected = """\
CNNStack(name='stack', num_layers=3, kernel_sizes=(7, 5, 5), num_filters=[32, 64, 64])
CNNStack(name='stack', num_layers=3, kernel_sizes=(7, 5, 5), num_filters=[32, 64, 64])
CNNStack(name='stack', num_layers=3, kernel_sizes=(7, 5, 5), num_filters=[32, 64, 64])
"""
コード例 #30
0
 def wrapper():
     """Parse the options for ``cls`` and call ``main`` with them.

     ``instance`` supplies the defaults; the enclosing ``argv`` is parsed
     when it is not ``None``, otherwise ``sys.argv[1:]`` is.
     """
     parser = ArgumentParser()
     parser.add_arguments(cls, dest='opts', default=instance)
     argcomplete.autocomplete(parser)

     if argv is None:
         cli_args = sys.argv[1:]
     else:
         cli_args = argv
     namespace = parser.parse_args(cli_args)
     return main(namespace.opts)