예제 #1
0
 def add_argparse_args(cls, parser: ArgumentParser, dest: str = None) -> None:
     """Register the command-line arguments of this class on ``parser``.

     Classes that don't rely on simple-parsing to declare their arguments
     should override this method.

     Parameters
     ----------
     parser : ArgumentParser
         The parser that receives the arguments.
     dest : str, optional
         Attribute of the parsed namespace under which the values are grouped.
         For dataclasses, a falsy value (the default, None) is replaced by the
         camel-cased ``__qualname__`` of the class. It is not forwarded in the
         LightningDataModule branch.

     Raises
     ------
     NotImplementedError
         If the class is neither a dataclass nor a LightningDataModule subclass.
     """
     if is_dataclass(cls):
         parser.add_arguments(cls, dest=dest or camel_case(cls.__qualname__))
         return

     if not issubclass(cls, LightningDataModule):
         message = (
             f"Don't know how to add command-line arguments for class "
             f"{cls}, since it isn't a dataclass and doesn't override the "
             f"`add_argparse_args` method!\n"
             f"Either make class {cls} a dataclass and add command-line "
             f"arguments as fields, or add an implementation for the "
             f"`add_argparse_args` and `from_argparse_args` classmethods."
         )
         raise NotImplementedError(message)

     # TODO: Test this case out (using a LightningDataModule as a Setting).
     # Delegate to the parent implementation; `dest` is not passed along.
     super().add_argparse_args(parser)  # type: ignore
예제 #2
0
def main_sl():
    """Run the PnnMethod on a supervised-learning (SL) setting.

    The setting, the config and the method are all built from command-line
    arguments. Returns the results object produced by ``setting.apply``.
    """
    parser = ArgumentParser(description=__doc__, add_dest_to_option_strings=False)

    # One argument group per component.
    # TODO: PNN is coded for the DomainIncrementalSetting, where the action space
    # is the same for each task.
    parser.add_arguments(TaskIncrementalSLSetting, dest="setting")
    Config.add_argparse_args(parser, dest="config")
    PnnMethod.add_argparse_args(parser, dest="method")

    args = parser.parse_args()

    setting: TaskIncrementalSLSetting = TaskIncrementalSLSetting.from_argparse_args(
        args, dest="setting"
    )
    config: Config = Config.from_argparse_args(args, dest="config")
    method: PnnMethod = PnnMethod.from_argparse_args(args, dest="method")
    # Give the method a reference to the same config used by the setting.
    method.config = config

    results = setting.apply(method, config=config)
    print(results.summary())
    return results
예제 #3
0
def parse_setting_and_method_instances(
    setting: Union[Setting, Type[Setting]],
    method: Union[Method, Type[Method]],
    argv: Union[str, List[str]] = None,
    strict_args: bool = False,
) -> Tuple[Setting, Method]:
    """Return instances of ``setting`` and ``method``, parsing classes from argv.

    Arguments that are already instances are returned unchanged. Arguments
    given as classes have their command-line arguments added to one shared
    parser, and are then instantiated with ``from_argparse_args``.

    Args:
        setting: A Setting instance, or a Setting subclass to parse.
        method: A Method instance, or a Method subclass to parse.
        argv: Command-line arguments, either as a list or as a single
            shell-style string. When None, the parser's default is used.
        strict_args: When True, unknown arguments are an error for the parser.
            Otherwise they are logged as a warning and ignored.

    Returns:
        The ``(setting, method)`` pair of instances.
    """
    import shlex

    # TODO: Should we raise an error if an argument appears both in the Setting
    # and the Method?
    if isinstance(argv, str):
        # A raw str would otherwise be iterated character by character by the
        # parser; split it the way a shell would.
        argv = shlex.split(argv)

    parser = ArgumentParser(description=__doc__, add_dest_to_option_strings=False)

    if not isinstance(setting, Setting):
        assert issubclass(setting, Setting)
        setting.add_argparse_args(parser)
    if not isinstance(method, Method):
        assert method is not None
        assert issubclass(method, Method)
        method.add_argparse_args(parser)

    if strict_args:
        args = parser.parse_args(argv)
    else:
        args, unused_args = parser.parse_known_args(argv)
        if unused_args:
            logger.warning("Unused command-line args: %s", unused_args)

    if not isinstance(setting, Setting):
        setting = setting.from_argparse_args(args)
    if not isinstance(method, Method):
        method = method.from_argparse_args(args)

    return setting, method
예제 #4
0
파일: ppo_v2.py 프로젝트: yycho0108/PPO
def main():
    """Parse ``Settings`` from the command line, then train or test with them."""
    parser = ArgumentParser()
    parser.add_arguments(Settings, dest='opts')
    # Enable shell tab-completion before the arguments are parsed.
    argcomplete.autocomplete(parser)
    opts = parser.parse_args().opts
    run = train if opts.train else test
    run(opts)
예제 #5
0
def test_passing_instance():
    """An instance passed to ``add_arguments`` provides the default values."""
    @dataclass
    class Foo:
        a: int = 123

    instance = Foo(456)
    parser = ArgumentParser()
    parser.add_arguments(instance, dest="foo")
    namespace = parser.parse_args("")
    assert namespace.foo.a == 456, vars(namespace)
예제 #6
0
 def add_model_specific_args(cls, arg_parser: ArgumentParser):
     """Add the model's arguments to ``arg_parser``, minus the ``--env`` option.

     The arguments are registered by ``cls.Model.add_model_specific_args``.
     The ``--env`` option string is then removed again. If the model did not
     register ``--env``, the parser is left untouched.
     """
     cls.Model.add_model_specific_args(arg_parser)
     env_action = next(
         (action for action in arg_parser._actions
          if "--env" in action.option_strings),
         None,
     )
     if env_action is None:
         # Nothing to remove.
         return
     # NOTE: relies on argparse's private conflict-resolution helper to drop the
     # '--env' option string from the action.
     arg_parser._handle_conflict_resolve(None, [("--env", env_action)])
예제 #7
0
 def add_argparse_args(cls, parser: ArgumentParser, dest: str = ""):
     """Add this Method's command-line arguments to ``parser``.

     NOTE: This behaves exactly like the base implementation; it is included
     only for illustration purposes.

     One argument is added per field of this dataclass. The values are stored
     on the namespace under ``dest``, which falls back to the camel-cased
     ``__qualname__`` of the class when empty.
     """
     target = dest if dest else camel_case(cls.__qualname__)
     parser.add_arguments(cls, dest=target)
예제 #8
0
def test_issue_29():
    """A variadic ``Tuple[str, ...]`` field collects every value given."""
    from simple_parsing import ArgumentParser

    @dataclass
    class MyCli:
        asdf: Tuple[str, ...]

    parser = ArgumentParser()
    parser.add_arguments(MyCli, dest="args")
    parsed = parser.parse_args(["--asdf", "asdf", "fgfh"])
    assert parsed.args == MyCli(asdf=("asdf", "fgfh"))
예제 #9
0
def test_issue62():
    """Enum fields and lists of enums are parsed from their member names."""
    import enum
    from typing import List

    from simple_parsing.helpers import list_field

    # Registering the dataclass below on a standalone parser must not raise.
    parser = ArgumentParser()

    class Color(enum.Enum):
        RED = "red"
        ORANGE = "orange"
        BLUE = "blue"

    class Temperature(enum.Enum):
        HOT = 1
        WARM = 0
        COLD = -1
        MONTREAL = -35

    @dataclass
    class MyPreferences(TestSetup):
        """You can use Enums"""

        color: Color = Color.BLUE  # my favorite colour
        # a list of colors
        color_list: List[Color] = list_field(Color.ORANGE)
        # Some floats.
        floats: List[float] = list_field(1.1, 2.2, 3.3)
        # pick a temperature
        temp: Temperature = Temperature.WARM
        # a list of temperatures
        temp_list: List[Temperature] = list_field(Temperature.COLD, Temperature.WARM)

    parser.add_arguments(MyPreferences, "my_preferences")

    base_argv = "--color ORANGE --color_list RED BLUE --temp MONTREAL"
    cases = [
        # temp_list not given: the field default is kept.
        (base_argv, [Temperature.COLD, Temperature.WARM]),
        (base_argv + " --temp_list MONTREAL HOT", [Temperature.MONTREAL, Temperature.HOT]),
    ]
    for argv, expected_temp_list in cases:
        assert MyPreferences.setup(argv) == MyPreferences(
            color=Color.ORANGE,
            color_list=[Color.RED, Color.BLUE],
            temp=Temperature.MONTREAL,
            temp_list=expected_temp_list,
        )

    # Enums can be looked up both by member name and by value.
    assert Temperature["MONTREAL"] is Temperature.MONTREAL
    assert Temperature(-35) is Temperature.MONTREAL
예제 #10
0
def to_simple_parsing_args(some_dataclass_type):
    """Parse ``some_dataclass_type`` from ``sys.argv`` and return the instance.

    Intended to be wrapped in a classmethod of a dataclass so that its fields
    become command-line arguments, for example::

        @classmethod
        def get_args(cls):
            args: cls = to_simple_parsing_args(cls)
            return args

    Conflict resolution is disabled (``ConflictResolution.NONE``).
    """
    from simple_parsing import ArgumentParser, ConflictResolution

    parser = ArgumentParser(conflict_resolution=ConflictResolution.NONE)
    parser.add_arguments(some_dataclass_type, dest='cfg')
    namespace = parser.parse_args()
    return namespace.cfg
예제 #11
0
def main():
    """Parse ``Options`` from the command line and print the result."""
    parser = ArgumentParser()
    parser.add_arguments(Options, dest="options")

    # Equivalent to:
    # subparsers = parser.add_subparsers(title="config", required=False)
    # parser.set_defaults(config=AConfig())
    # a_parser = subparsers.add_parser("a", help="A help.")
    # a_parser.add_arguments(AConfig, dest="config")
    # b_parser = subparsers.add_parser("b", help="B help.")
    # b_parser.add_arguments(BConfig, dest="config")

    namespace = parser.parse_args()
    print(namespace)
    print(namespace.options)
예제 #12
0
    def from_known_args(cls,
                        argv: Union[str, List[str]] = None,
                        reorder: bool = True,
                        strict: bool = False) -> Tuple[P, List[str]]:
        """Parse an instance of ``cls`` from ``argv``.

        Args:
            argv: Arguments to parse, either as a list or as one shell-style
                string. Defaults to ``sys.argv[1:]``.
            reorder: Forwarded to ``parse_known_args`` as ``attempt_to_reorder``.
                Only used when ``strict`` is False.
            strict: When True, ``parse_args`` is used, so unknown arguments are
                an error. Otherwise they are returned and logged at debug level.

        Returns:
            ``(instance, unused_args)``. The argv that was used is also stored
            on ``instance._argv``.
        """
        if argv is None:
            argv = sys.argv[1:]
        logger.debug("parsing an instance of class %s from argv %s", cls, argv)
        if isinstance(argv, str):
            argv = shlex.split(argv)

        parser = ArgumentParser(description=cls.__doc__,
                                add_dest_to_option_strings=False)
        cls.add_argparse_args(parser)

        # TODO: Set temporarily on the class, so its accessible in the class constructor
        previous_cls_argv = cls._argv
        cls._argv = argv
        try:
            if strict:
                args = parser.parse_args(argv)
                unused_args = []
            else:
                args, unused_args = parser.parse_known_args(
                    argv, attempt_to_reorder=reorder)
                if unused_args:
                    logger.debug("Unknown/unused args when parsing class %s: %s",
                                 cls, unused_args)
            instance: P = cls.from_argparse_args(args)
            # Save the argv that were used to create the instance.
            instance._argv = argv
        finally:
            # Restore the class attribute even if parsing or instantiation
            # fails (e.g. the parser raising SystemExit on bad arguments).
            cls._argv = previous_cls_argv
        return instance, unused_args
def parse_hparams_and_config() -> Tuple[HParams, MnistConfig]:
    """Parse the MNIST example's HParams and MnistConfig from the command line.

    Each group defaults to a freshly constructed instance. Unrecognised
    arguments are ignored. Returns ``(hparams, config)``.
    """
    # Training settings
    parser = ArgumentParser(description='PyTorch MNIST Example')
    for group_type, group_dest in ((HParams, "hparams"), (MnistConfig, "config")):
        parser.add_arguments(group_type, group_dest, default=group_type())
    namespace, _unknown = parser.parse_known_args()
    return namespace.hparams, namespace.config
예제 #14
0
def main_rl():
    """Run the PnnMethod on a two-task CartPole reinforcement-learning setting.

    Config and method options come from the command line. The setting itself
    is hard-coded below. Returns the results object of ``setting.apply``.
    """
    parser = ArgumentParser(description=__doc__, add_dest_to_option_strings=False)
    Config.add_argparse_args(parser, dest="config")
    PnnMethod.add_argparse_args(parser, dest="method")

    # Two tasks: only the pole "length" differs between them (0.3 -> 0.5).
    # NOTE: Untested with observe_state_directly=False. It runs, but whether
    # it converges hasn't been checked.
    schedule = {
        0: {"gravity": 10, "length": 0.3},
        1000: {"gravity": 10, "length": 0.5},
    }
    setting = TaskIncrementalRLSetting(
        dataset="cartpole",
        observe_state_directly=True,
        nb_tasks=2,
        train_task_schedule=schedule,
    )

    args = parser.parse_args()
    config: Config = Config.from_argparse_args(args, dest="config")
    method: PnnMethod = PnnMethod.from_argparse_args(args, dest="method")
    method.config = config

    results = setting.apply(method, config=config)
    print(results.summary())
    print(f"objective: {results.objective}")
    return results
예제 #15
0
def test_arg_and_dataclass_with_same_name(silent):
    """A plain ``--a`` argument clashes with a dataclass field named ``a``."""
    @dataclass
    class SomeClass:
        a: int = 1  # some docstring for attribute 'a'

    parser = ArgumentParser()
    parser.add_argument("--a", default=123)
    with raises(argparse.ArgumentError):
        parser.add_arguments(SomeClass, dest="some_class")
        parser.parse_args("")
예제 #16
0
def test_cmd_false_doesnt_create_conflicts():
    """A field declared with ``cmd=False`` doesn't clash with a same-named one."""
    @dataclass
    class A:
        batch_size: int = field(default=10, cmd=False)

    @dataclass
    class B:
        batch_size: int = 20

    parser = ArgumentParser(conflict_resolution=ConflictResolution.NONE)
    for dataclass_type, dest in ((A, "a"), (B, "b")):
        parser.add_arguments(dataclass_type, dest)
    parsed = parser.parse_args(["--batch_size", "32"])
    # A's cmd=False field keeps its default; B receives the parsed value.
    assert parsed.a == A()
    assert parsed.b == B(batch_size=32)
예제 #17
0
def update_settings(opts: dataclass, argv: List[str] = None):
    """Return the settings parsed from the command line, using ``opts`` as defaults.

    Uses `argparse`, `argcomplete` and `simple_parsing` under the hood. When
    ``opts`` is a ``Serializable``, ``_update_settings_from_file`` is applied
    first. It may update ``opts`` and ``argv`` and contribute parent parsers;
    presumably it loads values from a config file.

    Args:
        opts: A dataclass *instance* whose values act as the defaults.
        argv: Arguments to parse. Defaults to ``sys.argv[1:]``.

    Returns:
        The parsed dataclass instance (of the same type as ``opts``).

    Raises:
        ValueError: If ``opts`` is not a dataclass instance.
    """
    if not is_dataclass(opts):
        raise ValueError('Cannot update args on non-dataclass class')
    if isinstance(opts, type):
        # is_dataclass() is also True for dataclass *types*; type(opts) below
        # would then be `type` itself, so require an instance.
        raise ValueError('Cannot update args on a dataclass type; pass an instance')

    # Use default system argv if not supplied.
    argv = sys.argv[1:] if argv is None else argv

    # Update from config file, if applicable.
    parser_parents = []
    if isinstance(opts, Serializable):
        opts, argv, parser_parents = _update_settings_from_file(opts, argv)

    parser = ArgumentParser(parents=parser_parents)
    parser.add_arguments(type(opts), dest='opts', default=opts)
    if use_argcomplete:
        argcomplete.autocomplete(parser)
    args = parser.parse_args(argv)
    return args.opts
예제 #18
0
def test_vanilla_argparse_issue64():
    """This test shows that the ArgumentDefaultsHelpFormatter of argparse doesn't add
    the "(default: xyz)" if the 'help' argument isn't already passed!

    This begs the question: Should simple-parsing add a 'help' argument always, so that
    the formatter can then add the default string after?
    """
    import argparse
    from io import StringIO

    parser = ArgumentParser(
        "issue64", formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    group = parser.add_argument_group("Options ['options']",
                                      description="These are the options")
    group.add_argument("--foo",
                       type=str,
                       metavar="str",
                       default="aaa",
                       help="Description")
    group.add_argument("--bar", type=str, metavar="str", default="bbb")

    s = StringIO()
    parser.print_help(file=s)
    help_text = s.getvalue()
    # Python 3.10 renamed the "optional arguments:" section heading to
    # "options:"; normalise it so the expectation holds on every version.
    help_text = help_text.replace("optional arguments:", "options:", 1)

    assert help_text == textwrap.dedent("""\
    usage: issue64 [-h] [--foo str] [--bar str]

    options:
      -h, --help  show this help message and exit

    Options ['options']:
      These are the options

      --foo str   Description (default: aaa)
      --bar str
    """)
예제 #19
0
def demo():
    """Run the EwcMethod on a small CartPole setting to check that it works."""
    # The argument groups are added to the parser directly.
    parser = ArgumentParser(description=__doc__)
    EwcMethod.add_argparse_args(parser, dest="method")
    parser.add_arguments(Config, "config")
    args = parser.parse_args()

    method = EwcMethod.from_argparse_args(args, dest="method")
    config: Config = args.config

    # The same schedule is used for training and testing; each entry sets
    # the CartPole "gravity" and "length".
    task_schedule = {
        0: {"gravity": 10, "length": 0.2},
        1000: {"gravity": 100, "length": 1.2},
    }
    setting = TaskIncrementalRLSetting(
        dataset="cartpole",
        train_task_schedule=task_schedule,
        test_task_schedule=task_schedule,
        observe_state_directly=True,
    )

    results = setting.apply(method, config=config)
    print(results.summary())
예제 #20
0
def test_multiple_at_same_dest_throws_error():
    """Registering a dataclass a second time at the same dest is an error."""
    @dataclass
    class SomeClass:
        a: int = 123

    dest = "some_class"
    parser = ArgumentParser()
    parser.add_arguments(SomeClass, dest)
    with raises(argparse.ArgumentError):
        parser.add_arguments(SomeClass, dest)
예제 #21
0
def baseline_demo_command_line(rl: bool = False):
    """Run the BaselineMethod on a setting configured from the command line.

    Args:
        rl: When True, use the reinforcement-learning ``TaskIncrementalRLSetting``.
            Otherwise (the default), use the supervised ``TaskIncrementalSetting``.

    Returns:
        The results object returned by ``setting.apply``.
    """
    # `__doc__` must be passed as `description`: the first positional argument
    # of ArgumentParser is `prog`, the program name shown in the usage line.
    parser = ArgumentParser(description=__doc__, add_dest_to_option_strings=False)

    # Only ONE setting class can be stored at `args.setting`, so pick either
    # the supervised or the reinforcement-learning setting.
    setting_type = TaskIncrementalRLSetting if rl else TaskIncrementalSetting
    parser.add_arguments(setting_type, dest="setting")

    parser.add_arguments(Config, dest="config")
    BaselineMethod.add_argparse_args(parser, dest="method")

    args = parser.parse_args()

    setting: Setting = args.setting
    config: Config = args.config
    method: BaselineMethod = BaselineMethod.from_argparse_args(args, dest="method")

    results = setting.apply(method, config=config)
    print(results.summary())
    return results
예제 #22
0
파일: main.py 프로젝트: dtch1997/neural-car
def get_args():
    """Parse the command line in two stages.

    First, the runner / environment / agent choices are read. Then the
    selected classes register their own arguments through
    ``add_argparse_args``, which must return the parser, before the final
    full parse.

    Returns:
        The fully parsed namespace.
    """
    # Stage 1: a help-less pre-parser. If it had -h, `--help` would exit here
    # and the help would omit the class-specific arguments added below.
    selector = ArgumentParser(add_help=False)
    selector.add_argument('--runner',
                          default='test',
                          choices=runner_registry.keys(),
                          help="Specify whether to train or test an agent")
    selector.add_argument('--environment',
                          default="car",
                          choices=environment_registry.keys(),
                          help="Environment to train/test agent on")
    selector.add_argument('--agent',
                          default='scp',
                          choices=agent_registry.keys(),
                          help='Agent to train/test')

    known, _ = selector.parse_known_args()
    runner_class = runner_registry[known.runner]
    agent_class = agent_registry[known.agent]
    environment_class = environment_registry[known.environment]

    # Stage 2: the full parser inherits the selector arguments and adds the
    # arguments of the chosen classes.
    parser = ArgumentParser(parents=[selector])
    parser = agent_class.add_argparse_args(parser)
    parser = environment_class.add_argparse_args(parser)
    parser = runner_class.add_argparse_args(parser)
    return parser.parse_args()
예제 #23
0
    def test_vanilla_argparse_beheviour(self, options: Dict[str, Any],
                                        passed_arg: str, expected_value: Any):
        """Plain ``add_argument("--foo", **options)`` yields ``expected_value``."""
        parser = ArgumentParser()
        parser.add_argument("--foo", **options)

        argv = "" if passed_arg is DONT_PASS else shlex.split("--foo " + passed_arg)
        assert parser.parse_args(argv).foo == expected_value
예제 #24
0
def test_store_false_action():
    """``--no-cache`` sets ``Foo.no_cache`` to False; omitting it keeps True."""
    parser = ArgumentParser(add_option_string_dash_variants=True)
    parser.add_arguments(Foo, "foo")

    for argv, expected in (("--no-cache", False), ("", True)):
        foo: Foo = parser.parse_args(argv.split()).foo
        assert foo.no_cache == expected
예제 #25
0
def test_parsing_twice():
    """The same parser can be used for several parses in a row."""
    @dataclass
    class Foo:
        a: int = 123

    parser = ArgumentParser()
    parser.add_arguments(Foo, dest="foo")
    for argv, expected in (("", 123), ("--a 456".split(), 456)):
        args = parser.parse_args(argv)
        assert args.foo.a == expected, vars(args)
예제 #26
0
def test_works_fine_with_other_argparse_arguments(simple_attribute, silent):
    """Dataclass arguments coexist with plain ``add_argument`` arguments."""
    some_type, passed_value, expected_value = simple_attribute

    @dataclass
    class SomeClass:
        a: some_type  # type: ignore
        """some docstring for attribute 'a'"""

    x = 123
    parser = ArgumentParser()
    parser.add_argument("--x", type=int)
    parser.add_arguments(SomeClass, dest="some_class")

    parsed = parser.parse_args(shlex.split(f"--x {x} --a {passed_value}"))
    expected = argparse.Namespace(some_class=SomeClass(a=expected_value), x=x)
    assert parsed == expected
예제 #27
0
def test_arg_and_dataclass_with_same_name_after_prefixing(silent):
    """A ``--pre.a`` argument clashes with the prefixed field of a nested dataclass."""
    import dataclasses

    @dataclass
    class SomeClass:
        a: int = 1  # some docstring for attribute 'a'

    # Nested defaults use default_factory: since Python 3.11, dataclasses reject
    # an unhashable default (such as a dataclass instance) with ValueError.
    @dataclass
    class Parent:
        pre: SomeClass = dataclasses.field(default_factory=SomeClass)
        bla: SomeClass = dataclasses.field(default_factory=SomeClass)

    parser = ArgumentParser()
    parser.add_argument("--pre.a", default=123, type=int)
    with raises(argparse.ArgumentError):
        parser.add_arguments(Parent, dest="some_class")
        parser.parse_args("--pre.a 123 --pre.a 456".split())
예제 #28
0
    def test_reproduce(self):
        """Reproduce the bug where only the first of two values is consumed."""
        parser = ArgumentParser()
        parser.add_arguments(MyConfig, dest="cfg")
        # An initial empty parse, kept as in the original reproduction
        # (presumably relevant to triggering the bug).
        parser.parse_known_args([])
        args, extra = parser.parse_known_args(["--values", "3", "4"])

        # Expected: args.cfg.values == (3, 4) and extra == [].
        # Actual: only "3" ends up in `values`; "4" is left over as extra.
        assert args.cfg.values == 3
        assert extra == ["4"]
예제 #29
0
    def test_optional_list(self, passed_arg: str, expected_value: Any):
        """An ``Optional[List[str]]`` field parses to ``expected_value``."""
        parser = ArgumentParser()

        @dataclass
        class MyConfig:
            foo: Optional[List[str]] = None

        parser.add_arguments(MyConfig, dest="config")

        if passed_arg is DONT_PASS:
            argv = ""
        else:
            argv = shlex.split("--foo " + passed_arg)
        assert parser.parse_args(argv).config.foo == expected_value
예제 #30
0
def test_experiments():
    """A ``subparsers()`` field selects the Experiment subclass by subcommand."""
    from abc import ABC

    @dataclass
    class Experiment(ABC):
        dataset: str
        iid: bool = True

    @dataclass
    class Mnist(Experiment):
        dataset: str = "mnist"
        iid: bool = True

    @dataclass
    class MnistContinual(Experiment):
        dataset: str = "mnist"
        iid: bool = False

    @dataclass
    class Config:
        experiment: Experiment = subparsers({
            "mnist": Mnist,
            "mnist_continual": MnistContinual,
        })

    for config_field in dataclasses.fields(Config):
        assert simple_parsing.utils.is_subparser_field(config_field), config_field

    parser = ArgumentParser()
    parser.add_arguments(Config, "config")

    # Without a subcommand the parser must exit. Pass an explicit empty argv:
    # `parse_args()` with no argument would read sys.argv, i.e. the test
    # runner's own command line.
    with raises(SystemExit):
        parser.parse_args([])

    for command, expected_type, expected_iid in (
        ("mnist", Mnist, True),
        ("mnist_continual", MnistContinual, False),
    ):
        experiment = parser.parse_args([command]).config.experiment
        assert isinstance(experiment, expected_type)
        assert experiment.dataset == "mnist"
        assert experiment.iid is expected_iid