def test_add_argparse_args(cls, name):
    """
    Check that ``add_argparse_args`` extends an existing parser in place,
    registers the class's arguments under an argument group labelled ``name``
    alongside a pre-existing group, and that the combined parser parses a
    command line.
    """
    original = ArgumentParser()
    original.add_argument_group("main").add_argument(
        "--main_arg", type=str, default=""
    )
    parser = add_argparse_args(cls, original)
    # The helper is expected to return the very same parser object it was given.
    assert parser is original

    # Both the pre-existing "main" group and the class's group must be rendered.
    help_text = extract_help_text(parser)
    expected_fragments = ("main:", "--main_arg", f"{name}:", "--my_parameter")
    for fragment in expected_fragments:
        assert fragment in help_text
    # Every example class except the NoDoc variant contributes the help string "A thing".
    if cls is not AddArgparseArgsExampleClassNoDoc:
        assert "A thing" in help_text

    args = parser.parse_args(["--main_arg=abc", "--my_parameter=2"])
    assert (args.main_arg, args.my_parameter) == ("abc", 2)
Пример #2
0
    def add_argparse_args(cls, parent_parser: ArgumentParser,
                          **kwargs) -> ArgumentParser:
        """Extend an existing parser with the default `LightningDataModule` attributes.

        Extra keyword arguments are forwarded unchanged to the module-level
        ``add_argparse_args`` helper, whose result is returned.

        Example::

            parser = ArgumentParser(add_help=False)
            parser = LightningDataModule.add_argparse_args(parser)
        """
        extended_parser = add_argparse_args(cls, parent_parser, **kwargs)
        return extended_parser
Пример #3
0
def test_add_argparse_args_no_argument_group():
    """Check the legacy ``add_argparse_args(..., use_argument_group=False)`` path: it returns a
    new parser object, renders the class's arguments without a class-named group header, and the
    result can be parsed."""
    base_parser = ArgumentParser()
    base_parser.add_argument("--main_arg", type=str, default="")
    parser = add_argparse_args(AddArgparseArgsExampleClass, base_parser, use_argument_group=False)
    # Unlike the argument-group workflow, a different parser object is expected here.
    assert parser is not base_parser

    help_text = extract_help_text(parser)
    for flag in ("--main_arg", "--my_parameter"):
        assert flag in help_text
    # No group header named after the class should appear.
    assert "AddArgparseArgsExampleClass:" not in help_text

    args = parser.parse_args(["--main_arg=abc", "--my_parameter=2"])
    assert args.main_arg == "abc" and args.my_parameter == 2
Пример #4
0
 def add_argparse_args(cls, parent_parser: ArgumentParser,
                       **kwargs) -> ArgumentParser:
     """Add this class's arguments to ``parent_parser`` via the module-level helper.

     Extra keyword arguments are passed through untouched, and the helper's
     return value is returned as-is.
     """
     result = add_argparse_args(cls, parent_parser, **kwargs)
     return result
Пример #5
0
 def add_argparse_args(cls, *args, **kwargs) -> ArgumentParser:
     """Add the arguments of the Lightning trainer (``PlTrainer``) to a parser.

     ``cls`` is deliberately not used: the Lightning trainer implementation does
     not support subclasses, so arguments are always taken from ``PlTrainer``.
     Context: https://github.com/PyTorchLightning/lightning-flash/issues/342#issuecomment-848892447
     """
     trainer_cls = PlTrainer
     return add_argparse_args(trainer_cls, *args, **kwargs)
Пример #6
0
 def add_argparse_args(cls, parent_parser: ArgumentParser,
                       **kwargs) -> ArgumentParser:
     """Extend an existing parser with the default `LightningDataModule` attributes.

     Keyword arguments are forwarded unchanged to the module-level helper.
     """
     augmented = add_argparse_args(cls, parent_parser, **kwargs)
     return augmented
def test_negative_add_argparse_args():
    """Passing an argument group instead of an ``ArgumentParser`` must raise ``RuntimeError``.

    The parser and the group are built outside ``pytest.raises`` so that only the
    ``add_argparse_args`` call can satisfy the expectation. The trailing period in
    ``match`` is escaped because ``match`` is a regular expression searched
    against the exception message.
    """
    parser = ArgumentParser()
    bad_group = parser.add_argument_group("bad workflow")
    with pytest.raises(RuntimeError,
                       match=r"Please only pass an ArgumentParser instance\."):
        add_argparse_args(AddArgparseArgsExampleClass, bad_group)
Пример #8
0
def test_add_argparse_args_invalid():
    """Ensure ``add_argparse_args`` does not raise ``TypeError`` for a class whose arguments are
    typed as ``typing.Generic`` (an issue seen on Python 3.6)."""
    fresh_parser = ArgumentParser()
    # Success is simply "no exception"; the returned parser is not inspected.
    add_argparse_args(AddArgparseArgsExampleClassGeneric, fresh_parser)