Exemplo n.º 1
0
def test_arg_and_dataclass_with_same_name(silent):
    """A plain ``--a`` option must conflict with a dataclass field named ``a``.

    Both ``add_arguments`` and ``parse_args`` sit inside ``raises``: the
    conflict may only surface once the parser actually builds the dataclass
    arguments (presumably deferred until parsing — confirm against the
    parser implementation).
    """
    @dataclass
    class SomeClass:
        a: int = 1  # some docstring for attribute 'a'

    parser = ArgumentParser()
    parser.add_argument("--a", default=123)
    with raises(argparse.ArgumentError):
        parser.add_arguments(SomeClass, dest="some_class")
        # Explicit empty argument list; a string argument would be iterated
        # character by character, which only works here because it is empty.
        parser.parse_args([])
Exemplo n.º 2
0
    def test_vanilla_argparse_beheviour(self, options: Dict[str, Any],
                                        passed_arg: str, expected_value: Any):
        """Parse a single ``--foo`` option configured from ``options``.

        When ``passed_arg`` is the ``DONT_PASS`` sentinel, the option is
        omitted from the command line. Otherwise ``--foo <passed_arg>`` is
        shell-split and parsed. Either way, the parsed value must equal
        ``expected_value``.
        """
        parser = ArgumentParser()
        parser.add_argument("--foo", **options)

        argv = "" if passed_arg is DONT_PASS else shlex.split("--foo " + passed_arg)
        namespace = parser.parse_args(argv)
        assert namespace.foo == expected_value
Exemplo n.º 3
0
def test_works_fine_with_other_argparse_arguments(simple_attribute, silent):
    """A dataclass group and a plain ``--x`` option can be parsed together.

    ``simple_attribute`` supplies the field type, the raw command-line value
    and the value expected after parsing.
    """
    some_type, passed_value, expected_value = simple_attribute

    @dataclass
    class SomeClass:
        a: some_type  # type: ignore
        """some docstring for attribute 'a'"""

    parser = ArgumentParser()
    parser.add_argument("--x", type=int)
    parser.add_arguments(SomeClass, dest="some_class")

    x_value = 123
    command_line = f"--x {x_value} --a {passed_value}"
    parsed = parser.parse_args(shlex.split(command_line))

    expected = argparse.Namespace(some_class=SomeClass(a=expected_value), x=x_value)
    assert parsed == expected
Exemplo n.º 4
0
def test_arg_and_dataclass_with_same_name_after_prefixing(silent):
    """An existing ``--pre.a`` option must conflict with the nested field ``pre.a``.

    ``Parent`` nests two ``SomeClass`` instances (``pre`` and ``bla``). The
    manually added ``--pre.a`` collides with the argument generated for
    ``Parent.pre.a``, so registering or parsing must raise
    ``argparse.ArgumentError``.
    """
    from dataclasses import field

    @dataclass
    class SomeClass:
        a: int = 1  # some docstring for attribute 'a'

    # default_factory instead of a shared ``SomeClass()`` default: since
    # Python 3.11, dataclasses reject unhashable defaults with a ValueError at
    # class creation, and a non-frozen dataclass with eq=True is unhashable.
    # That error would abort the test before the ``raises`` block is reached.
    @dataclass
    class Parent:
        pre: SomeClass = field(default_factory=SomeClass)
        bla: SomeClass = field(default_factory=SomeClass)

    parser = ArgumentParser()
    parser.add_argument("--pre.a", default=123, type=int)
    with raises(argparse.ArgumentError):
        parser.add_arguments(Parent, dest="some_class")
        parser.parse_args("--pre.a 123 --pre.a 456".split())
Exemplo n.º 5
0
def get_args():
    """Parse the command line in two passes.

    The first pass reads only ``--runner``, ``--environment`` and ``--agent``.
    These select classes from ``runner_registry``, ``environment_registry``
    and ``agent_registry``. Each selected class then extends the parser via
    its ``add_argparse_args`` hook, and the full command line is parsed.

    Returns:
        The namespace produced by the final ``parse_args()`` call.
    """
    # The selector parser has no -h/--help. If it had one, ``--help`` would
    # be handled (help printed, program exited) during the first pass, before
    # the agent/environment/runner options exist, so they would be missing
    # from the help output.
    selector = ArgumentParser(add_help=False)
    selector.add_argument('--runner',
                          default='test',
                          choices=runner_registry.keys(),
                          help="Specify whether to train or test an agent")
    selector.add_argument('--environment',
                          default="car",
                          choices=environment_registry.keys(),
                          help="Environment to train/test agent on")
    selector.add_argument('--agent',
                          default='scp',
                          choices=agent_registry.keys(),
                          help='Agent to train/test')

    # Unknown options belong to the classes selected here; ignore them for now.
    selection, _ = selector.parse_known_args()
    runner_class = runner_registry[selection.runner]
    agent_class = agent_registry[selection.agent]
    environment_class = environment_registry[selection.environment]

    # ``parents`` copies the selector's arguments into the full parser, which
    # keeps the default -h/--help and therefore documents every option.
    parser = ArgumentParser(parents=[selector])
    parser = agent_class.add_argparse_args(parser)
    parser = environment_class.add_argparse_args(parser)
    parser = runner_class.add_argparse_args(parser)
    return parser.parse_args()
Exemplo n.º 6
0
# examples/demo.py
from dataclasses import dataclass
from simple_parsing import ArgumentParser

@dataclass
class Options:
    """ Help string for this group of command-line arguments """
    log_dir: str  # Help string for a required str argument
    learning_rate: float = 1e-4  # Help string for a float argument


# One plain option registered first, then the dataclass-generated group.
parser = ArgumentParser()
parser.add_argument("--foo", type=int, default=123, help="foo help")
parser.add_arguments(Options, dest="options")

args = parser.parse_args()
print(f"foo: {args.foo}")
print(f"options: {args.options}")
# Hyperparameters for the WGAN variant. This extends GANHyperParameters
# (defined elsewhere) with the `lambda_coeff` field (default 10).
# NOTE: the inline field comment doubles as command-line help text when the
# class is passed to `parser.add_arguments`, so edit it with care.
@dataclass
class WGANHyperParameters(GANHyperParameters):
    lambda_coeff: float = 10  # the lambda penalty coefficient.


# Hyperparameters for the gradient-penalty WGAN. This extends
# WGANHyperParameters with the `gp_penalty` field (default 1e-6).
# NOTE: the inline field comment doubles as command-line help text when the
# class is passed to `parser.add_arguments`, so edit it with care.
@dataclass
class WGANGPHyperParameters(WGANHyperParameters):
    gp_penalty: float = 1e-6  # Gradient penalty coefficient


# Hyperparameters come either from the command line or, when --load_path is
# given, from a JSON file. In the latter case the parsed command-line values
# are ignored.
parser = ArgumentParser()
parser.add_argument(
    "--load_path",
    type=str,
    default=None,
    help="If given, the HyperParameters are read from the given file instead of from the command-line.",
)
parser.add_arguments(WGANGPHyperParameters, dest="hparams")

args = parser.parse_args()

load_path: str = args.load_path
hparams: WGANGPHyperParameters = (
    args.hparams
    if load_path is None
    else WGANGPHyperParameters.load_json(load_path)
)

assert hparams == WGANGPHyperParameters(batch_size=32,
                                        d_steps=1,