def test_add_argparse_args(cls, name):
    """Check that ``add_argparse_args`` extends a parser in place and keeps its groups.

    The parser already has a ``main`` argument group. After the call, the help
    text must still list that group and must also list a group titled ``name``
    holding ``--my_parameter``. The extended parser must parse a sample argv.
    """
    base_parser = ArgumentParser()
    main_group = base_parser.add_argument_group("main")
    main_group.add_argument("--main_arg", type=str, default="")

    # Without ``use_argument_group=False`` the very same parser object must come back.
    result_parser = add_argparse_args(cls, base_parser)
    assert result_parser is base_parser

    # Both the pre-existing "main" group and the class's own group must be listed.
    help_text = extract_help_text(result_parser)
    for expected in ("main:", "--main_arg", f"{name}:", "--my_parameter"):
        assert expected in help_text
    # "A thing" is only required for classes other than the no-doc example
    # (presumably it originates from the class docstring).
    if cls is not AddArgparseArgsExampleClassNoDoc:
        assert "A thing" in help_text

    parsed = result_parser.parse_args(["--main_arg=abc", "--my_parameter=2"])
    assert parsed.main_arg == "abc"
    assert parsed.my_parameter == 2
def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser:
    """Extend an existing ``ArgumentParser`` with the default ``LightningDataModule`` attributes.

    This is a thin wrapper. It passes ``cls``, ``parent_parser`` and any extra
    keyword arguments unchanged to the free function ``add_argparse_args``,
    which is expected to be the shared argparse utility rather than this method,
    and it returns that function's result.

    Example::

        parser = ArgumentParser(add_help=False)
        parser = LightningDataModule.add_argparse_args(parser)
    """
    extended_parser = add_argparse_args(cls, parent_parser, **kwargs)
    return extended_parser
def test_add_argparse_args_no_argument_group():
    """Check the legacy ``use_argument_group=False`` path of ``add_argparse_args``.

    In this mode the function must return a parser other than the one passed in.
    The class's options must appear without an ``AddArgparseArgsExampleClass:``
    group heading, and parsing must still work.
    """
    base_parser = ArgumentParser()
    base_parser.add_argument("--main_arg", type=str, default="")

    result_parser = add_argparse_args(AddArgparseArgsExampleClass, base_parser, use_argument_group=False)
    assert result_parser is not base_parser

    help_text = extract_help_text(result_parser)
    for expected in ("--main_arg", "--my_parameter"):
        assert expected in help_text
    # No dedicated group heading for the class in the legacy workflow.
    assert "AddArgparseArgsExampleClass:" not in help_text

    parsed = result_parser.parse_args(["--main_arg=abc", "--my_parameter=2"])
    assert parsed.main_arg == "abc"
    assert parsed.my_parameter == 2
def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser:
    """Forward ``cls`` and ``parent_parser`` to the shared ``add_argparse_args`` utility.

    Args:
        parent_parser: The parser to extend.
        **kwargs: Passed through unchanged.

    Returns:
        Whatever the delegated ``add_argparse_args`` call returns.
    """
    result = add_argparse_args(cls, parent_parser, **kwargs)
    return result
def add_argparse_args(cls, *args, **kwargs) -> ArgumentParser:
    """Extend a parser with the arguments of the base ``PlTrainer``, ignoring ``cls``.

    All positional and keyword arguments are forwarded unchanged. The result of
    the delegated ``add_argparse_args`` call is returned.
    """
    # The Lightning trainer implementation does not support subclasses, so the
    # base trainer is always used instead of ``cls``.
    # Context: https://github.com/PyTorchLightning/lightning-flash/issues/342#issuecomment-848892447
    trainer_cls = PlTrainer
    return add_argparse_args(trainer_cls, *args, **kwargs)
def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser:
    """Extend an existing ``ArgumentParser`` with the default ``LightningDataModule`` attributes.

    The work is delegated to the free function ``add_argparse_args``, passing
    ``cls``, ``parent_parser`` and any extra keyword arguments through. Its
    result is returned unchanged.
    """
    updated_parser = add_argparse_args(cls, parent_parser, **kwargs)
    return updated_parser
def test_negative_add_argparse_args():
    """Check that ``add_argparse_args`` rejects an argument group in place of a parser.

    Passing the result of ``ArgumentParser.add_argument_group`` must raise
    ``RuntimeError`` with a message asking for an ``ArgumentParser`` instance.
    """
    # Build the bad input outside ``pytest.raises`` so that only the call under
    # test can satisfy the expectation. A RuntimeError from setup would
    # otherwise be mistaken for the expected rejection.
    parser = ArgumentParser()
    bad_group = parser.add_argument_group("bad workflow")
    with pytest.raises(RuntimeError, match="Please only pass an ArgumentParser instance."):
        add_argparse_args(AddArgparseArgsExampleClass, bad_group)
def test_add_argparse_args_invalid():
    """Smoke test for arguments annotated with ``typing.Generic``.

    ``add_argparse_args`` must not raise ``TypeError`` for a class whose
    arguments use that annotation (the original concern was Python 3.6).
    """
    fresh_parser = ArgumentParser()
    add_argparse_args(AddArgparseArgsExampleClassGeneric, fresh_parser)