def test_dict_union(self):
    """Dict types whose values are unions parse, coerce and dump correctly."""
    class MyEnum(Enum):
        ab = 1

    parser = ArgumentParser(error_handler=None)
    parser.add_argument('--dict1', type=Dict[int, Optional[Union[float, MyEnum]]])
    parser.add_argument('--dict2', type=Dict[str, Union[bool, Path_fc]])

    parsed = parser.parse_args(
        ['--dict1={"2":4.5, "6":"ab"}', '--dict2={"a":true, "b":"f"}'])
    self.assertEqual({2: 4.5, 6: MyEnum.ab}, parsed['dict1'])
    self.assertEqual({'a': True, 'b': 'f'}, parsed['dict2'])
    # The string value in dict2 must come back as a Path instance.
    self.assertIsInstance(parsed['dict2']['b'], Path)

    only_null = parser.parse_args(['--dict1={"5":null}'])['dict1']
    self.assertEqual({5: None}, only_null)

    # A JSON list is not acceptable where a dict is expected.
    with self.assertRaises(ParserError):
        parser.parse_args(['--dict1=["a", "b"]'])

    round_tripped = yaml.safe_load(parser.dump(parsed))
    expected_dump = {
        'dict1': {'2': 4.5, '6': 'ab'},
        'dict2': {'a': True, 'b': 'f'},
    }
    self.assertEqual(expected_dump, round_tripped)
def test_subcommand_with_subclass_default_override_lightning_issue_10859(self):
    """A CLI subclass value in a subcommand replaces the subclass default exactly."""
    class Arch:
        def __init__(self, a: int = 1):
            pass

    class ArchB(Arch):
        def __init__(self, a: int = 2, b: int = 3):
            pass

    class ArchC(Arch):
        def __init__(self, a: int = 4, c: int = 5):
            pass

    parser = ArgumentParser(error_handler=None)
    subcommands = parser.add_subcommands()
    fit_parser = ArgumentParser()
    fit_parser.add_argument('--arch', type=Arch)

    with mock_module(Arch, ArchB, ArchC) as module:
        fit_parser.set_defaults(arch={'class_path': f'{module}.ArchB'})
        override = {
            'class_path': f'{module}.ArchC',
            'init_args': {'a': 10, 'c': 11},
        }
        subcommands.add_subcommand('fit', fit_parser)
        parsed = parser.parse_args(['fit', f'--arch={json.dumps(override)}'])
        # No leftovers from the ArchB default (e.g. 'b') may remain.
        self.assertEqual(parsed.fit.arch.as_dict(), override)
def test_enum(self):
    """Enum-typed options accept member names only and show up in dump/help."""
    class MyEnum(Enum):
        A = 1
        B = 2
        C = 3

    parser = ArgumentParser(error_handler=None)
    parser.add_argument('--enum', type=MyEnum, default=MyEnum.C, help='Description')

    for member_name in ('A', 'B', 'C'):
        self.assertEqual(MyEnum[member_name],
                         parser.parse_args([f'--enum={member_name}']).enum)
    # Unknown names, wrong case and raw values are all rejected.
    for rejected in ('X', 'b', 2):
        with self.assertRaises(ParserError):
            parser.parse_args([f'--enum={rejected}'])

    only_c = parser.parse_args(['--enum=C'], with_meta=False)
    self.assertEqual('enum: C\n', parser.dump(only_c))

    help_buffer = StringIO()
    parser.print_help(help_buffer)
    self.assertIn('Description (type: MyEnum, default: C)', help_buffer.getvalue())
def test_list_path(self):
    """List[Path_fc] parses into a list whose items are Path objects."""
    parser = ArgumentParser()
    parser.add_argument('--paths', type=List[Path_fc])
    parsed_paths = parser.parse_args(['--paths=["file1", "file2"]']).paths
    self.assertEqual(['file1', 'file2'], parsed_paths)
    for position in (0, 1):
        self.assertIsInstance(parsed_paths[position], Path)
def test_list(self):
    """Iterable/List/Sequence of int all parse a JSON list into a Python list."""
    for sequence_type in (Iterable, List, Sequence):
        with self.subTest(str(sequence_type)):
            parser = ArgumentParser()
            parser.add_argument('--list', type=sequence_type[int])
            self.assertEqual([1, 2], parser.parse_args(['--list=[1, 2]']).list)
def test_ActionJsonnet(self):
    """Exercise ActionJsonnet with and without ext_vars, plus invalid setups."""
    parser = ArgumentParser(default_meta=False, error_handler=None)
    parser.add_argument('--input.ext_vars', action=ActionJsonnetExtVars())
    jsonnet_action = ActionJsonnet(ext_vars='input.ext_vars',
                                   schema=json.dumps(example_schema))
    parser.add_argument('--input.jsonnet', action=jsonnet_action)

    with_ext_vars = parser.parse_args([
        '--input.ext_vars', '{"param": 123}',
        '--input.jsonnet', example_2_jsonnet,
    ])
    evaluated = with_ext_vars.input.jsonnet
    self.assertEqual(123, evaluated['param'])
    records = evaluated['records']
    self.assertEqual(9, len(records))
    self.assertEqual('#8', records[-2]['ref'])
    self.assertEqual(15.5, records[-2]['val'])

    # example_1 must evaluate to the same records as example_2 with ext_vars.
    plain = parser.parse_args(['--input.jsonnet', example_1_jsonnet])
    self.assertEqual(plain.input.jsonnet['records'], records)

    failing_argv = (
        ['--input.ext_vars', '{"param": "a"}', '--input.jsonnet', example_2_jsonnet],
        ['--input.jsonnet', example_2_jsonnet],
    )
    for argv in failing_argv:
        with self.assertRaises(ParserError):
            parser.parse_args(argv)

    with self.assertRaises(ValueError):
        ActionJsonnet(ext_vars=2)
    with self.assertRaises(ValueError):
        ActionJsonnet(schema='.' + json.dumps(example_schema))
def test_class_type_with_default_config_files(self):
    """A class-typed argument is populated from a default config file."""
    calendar_config = {
        'class_path': 'calendar.Calendar',
        'init_args': {'firstweekday': 3},
    }
    config_path = os.path.join(self.tmpdir, 'config.yaml')
    with open(config_path, 'w') as config_file:
        json.dump({'data': {'cal': calendar_config}}, config_file)

    class MyClass:
        def __init__(self, cal: Optional[Calendar] = None, val: int = 2):
            self.cal = cal

    parser = ArgumentParser(error_handler=None, default_config_files=[config_path])
    parser.add_argument('--op', default='from default')
    parser.add_class_arguments(MyClass, 'data')

    defaults = parser.get_defaults()
    self.assertEqual(config_path, str(defaults['__default_config__']))
    self.assertEqual(defaults.data.cal.as_dict(), calendar_config)
    dumped = parser.dump(defaults)
    for fragment in ('class_path: calendar.Calendar\n', 'firstweekday: 3\n'):
        self.assertIn(fragment, dumped)

    self.assertEqual(parser.parse_args([]).data.cal.as_dict(), calendar_config)

    # With defaults=False only the explicitly given class_path remains.
    no_defaults = parser.parse_args(['--data.cal.class_path=calendar.Calendar'],
                                    defaults=False)
    self.assertEqual(no_defaults.data.cal, Namespace(class_path='calendar.Calendar'))
def get_parser_lv2():
    """Return a parser with one top-level option (--a1) and one grouped option (--a2)."""
    nested = ArgumentParser(description='parser_lv2 description')
    nested.add_argument('--a1', help='lv2_a1 help')
    nested_group = nested.add_argument_group(description='group_lv2 description')
    nested_group.add_argument('--a2', help='lv2_a2 help')
    return nested
def test_ActionYesNo_parse_env(self):
    """ActionYesNo values can be set from environment variables."""
    parser = example_parser()
    env_cases = (
        ('APP_BOOLS__DEF_FALSE', 'true', 'def_false', True),
        ('APP_BOOLS__DEF_FALSE', 'yes', 'def_false', True),
        ('APP_BOOLS__DEF_TRUE', 'false', 'def_true', False),
        ('APP_BOOLS__DEF_TRUE', 'no', 'def_true', False),
    )
    for env_name, env_value, attribute, expected in env_cases:
        bools = parser.parse_env({env_name: env_value}).bools
        self.assertEqual(expected, getattr(bools, attribute))

    env_parser = ArgumentParser(default_env=True, env_prefix='APP')
    env_parser.add_argument('--op', action=ActionYesNo, default=False)
    self.assertEqual(True, env_parser.parse_env({'APP_OP': 'true'}).op)
def get_config_parser():
    """Return a parser with a single integer ``--factor`` option (default 2)."""
    config_parser = ArgumentParser()
    config_parser.add_argument(
        "--factor",
        type=int,
        default=2,
        help="Factor to multiply",
    )
    return config_parser
def test_typehint_serialize_list(self):
    """Serializing a list of PositiveInt works; a negative item is rejected."""
    parser = ArgumentParser()
    list_action = parser.add_argument('--list', type=Union[PositiveInt, List[PositiveInt]])
    positives = [PositiveInt(number) for number in (1, 2)]
    self.assertEqual([1, 2], list_action.serialize(positives))
    with self.assertRaises(ValueError):
        list_action.serialize([1, -2])
def test_list_enum(self):
    """List[Enum] parses member names in the given order."""
    class MyEnum(Enum):
        ab = 0
        xy = 1

    parser = ArgumentParser(error_handler=None)
    parser.add_argument('--list', type=List[MyEnum])
    parsed_members = parser.parse_args(['--list=["xy", "ab"]']).list
    self.assertEqual([MyEnum.xy, MyEnum.ab], parsed_members)
def test_typehint_serialize_enum(self):
    """Enum members serialize to their names; non-members raise ValueError."""
    class MyEnum(Enum):
        a = 1
        b = 2

    parser = ArgumentParser()
    enum_action = parser.add_argument('--enum', type=Optional[MyEnum])
    self.assertEqual('b', enum_action.serialize(MyEnum.b))
    with self.assertRaises(ValueError):
        enum_action.serialize('x')
def test_nested_tuples(self):
    """Nested Tuple types convert nested JSON lists into nested tuples."""
    int_float = Tuple[int, float]
    parser = ArgumentParser(error_handler=None)
    parser.add_argument('--tuple', type=Tuple[Tuple[str, str], Tuple[int_float, int_float]])
    parsed_tuple = parser.parse_args(
        ['--tuple=[["foo", "bar"], [[1, 2.02], [3, 3.09]]]']).tuple
    self.assertEqual((('foo', 'bar'), ((1, 2.02), (3, 3.09))), parsed_tuple)
def test_class_type_subclass_given_by_name_issue_84(self):
    """Subclasses can be given by bare name; ambiguous names produce an error."""
    # Defined locally so the help check below can assert it is NOT listed.
    class LocalCalendar(Calendar):
        pass

    parser = ArgumentParser()
    parser.add_argument('--op', type=Union[Calendar, GzipFile, None])
    by_name = parser.parse_args(['--op=TextCalendar'])
    self.assertEqual(by_name.op.class_path, 'calendar.TextCalendar')

    help_buffer = StringIO()
    parser.print_help(help_buffer)
    help_text = help_buffer.getvalue()
    for class_path in ('calendar.Calendar', 'calendar.TextCalendar', 'gzip.GzipFile'):
        self.assertIn(class_path, help_text)
    self.assertNotIn('LocalCalendar', help_text)

    # A second class named HTMLCalendar makes the bare name ambiguous.
    class HTMLCalendar(Calendar):
        pass

    with mock_module(HTMLCalendar) as module:
        stderr_buffer = StringIO()
        with redirect_stderr(stderr_buffer), self.assertRaises(SystemExit):
            parser.parse_args(['--op.help=HTMLCalendar'])
        error_text = stderr_buffer.getvalue()
        self.assertIn('Give the full class path to avoid ambiguity', error_text)
        self.assertIn(f'{module}.HTMLCalendar', error_text)
def test_nargs_questionmark(self):
    """An optional positional (nargs='?') is None when omitted and type-checked otherwise."""
    parser = ArgumentParser(error_handler=None)
    parser.add_argument('p1')
    parser.add_argument('p2', nargs='?', type=OpenUnitInterval)

    self.assertIsNone(parser.parse_args(['a']).p2)
    self.assertEqual(0.5, parser.parse_args(['a', '0.5']).p2)
    with self.assertRaises(ParserError):
        parser.parse_args(['a', 'b'])
def read_config(args: Optional[List[str]] = None) -> SimpleNamespace:
    """Parse command line arguments and return the benchmark configuration.

    The single positional ``config`` argument is handled by
    ``ActionJsonSchema(schema=SCHEMA)``.

    :param args: argument list to parse. Defaults to ``None`` so the function
        can be called without arguments. ``None`` is forwarded to
        ``parse_args``, which (argparse convention) then reads the process
        command line.
    :returns: benchmark configuration.
    """
    # NOTE(review): the return annotation says SimpleNamespace, but the
    # docstring historically said "JSON object"; confirm which type
    # ActionJsonSchema actually yields for ``config``.
    parser = ArgumentParser()
    # NOTE(review): presumably '%s' is expanded by ActionJsonSchema into the
    # schema text in --help output; confirm against the jsonargparse docs.
    parser.add_argument("config", action=ActionJsonSchema(schema=SCHEMA), help="%s")
    config = parser.parse_args(args).config
    return config
def test_ActionParser_failures(self):
    """Self-nesting a parser and giving a scalar to a nested parser both fail."""
    inner_parser = ArgumentParser()
    inner_parser.add_argument('--op')
    parser = ArgumentParser(error_handler=None)
    parser.add_argument('--inner', action=ActionParser(parser=inner_parser))

    # Nesting a parser inside itself is rejected at definition time.
    with self.assertRaises(ValueError):
        parser.add_argument('--mistake', action=ActionParser(parser=parser))
    with self.assertRaises(ParserError):
        parser.parse_args(['--inner=1'])
def test_parser_mode_subparsers(self):
    """Setting parser_mode on the root parser propagates to its subcommands."""
    subparser = ArgumentParser()
    parser = ArgumentParser()
    parser.add_subcommands().add_subcommand('sub', subparser)

    # patch.dict restores the loaders mapping on exit, so the custom loader
    # registered here does not leak outside this block.
    with unittest.mock.patch.dict('jsonargparse.loaders_dumpers.loaders'):
        set_loader('custom', yaml.safe_load)
        parser.parser_mode = 'custom'
        for checked_parser in (parser, subparser):
            self.assertEqual('custom', checked_parser.parser_mode)
def test_ActionYesNo_old_bool(self):
    """ActionYesNo with nargs=1 and no_prefix=None accepts true/yes/false/no values."""
    parser = ArgumentParser(error_handler=None)
    parser.add_argument('--val', nargs=1, action=ActionYesNo(no_prefix=None))
    self.assertEqual(False, parser.get_defaults().val)

    word_to_bool = (('true', True), ('yes', True), ('false', False), ('no', False))
    for word, expected in word_to_bool:
        self.assertEqual(expected, parser.parse_args(['--val', word]).val)

    with self.assertRaises(ParserError):
        parser.parse_args(['--val', '1'])
def test_ActionParser_conflict(self):
    """Nesting a parser under a prefix that already has a clashing key fails."""
    nested_parser = ArgumentParser()
    nested_parser.add_argument('--op')
    parser = ArgumentParser(error_handler=None)
    parser.add_argument('--inner.op')
    # '--inner.op' already exists, and nested_parser also defines '--op'.
    with self.assertRaises(ValueError):
        parser.add_argument('--inner', action=ActionParser(nested_parser))
def test_class_path_override_with_default_config_files(self):
    """A class_path given on the CLI overrides one from a default config file."""
    class MyCalendar(Calendar):
        def __init__(self, *args, param: str = '0', **kwargs):
            super().__init__(*args, **kwargs)

    with mock_module(MyCalendar) as module:
        default_cal = {
            'class_path': f'{module}.MyCalendar',
            'init_args': {'firstweekday': 2, 'param': '1'},
        }
        config_path = os.path.join(self.tmpdir, 'config.yaml')
        with open(config_path, 'w') as handle:
            json.dump({'cal': default_cal}, handle)

        parser = ArgumentParser(error_handler=None, default_config_files=[config_path])
        parser.add_argument('--cal', type=Optional[Calendar])

        from_defaults = parser.instantiate_classes(parser.get_defaults())
        self.assertIsInstance(from_defaults['cal'], MyCalendar)

        overridden = parser.parse_args([
            '--cal={"class_path": "calendar.Calendar", "init_args": {"firstweekday": 3}}'
        ])
        # Exact type check: a MyCalendar instance would also be a Calendar.
        self.assertEqual(type(parser.instantiate_classes(overridden)['cal']), Calendar)
def test_class_type_subclass_nested_help(self):
    """Nested --*.help options print nested class arguments; misuse errors out."""
    class Class:
        def __init__(self, cal: Calendar, p1: int = 0):
            self.cal = cal

    parser = ArgumentParser()
    parser.add_argument('--op', type=Class)

    # r'[\s=]' also splits on '=', yielding "--opt value" pairs (space
    # separator); r'\s' keeps the "--opt=value" form. The label names the
    # separator being exercised.
    for pattern, separator_label in ((r'[\s=]', '" "'), (r'\s', '"="')):
        with self.subTest(separator_label), mock_module(Class) as module:
            stdout_buffer = StringIO()
            argv = re.split(
                pattern,
                f'--op.help={module}.Class --op.init_args.cal.help=TextCalendar')
            with redirect_stdout(stdout_buffer), self.assertRaises(SystemExit):
                parser.parse_args(argv)
            self.assertIn('--op.init_args.cal.init_args.firstweekday',
                          stdout_buffer.getvalue())

    with self.subTest('invalid'), mock_module(Class) as module:
        stderr_buffer = StringIO()
        with redirect_stderr(stderr_buffer), self.assertRaises(SystemExit):
            parser.parse_args([f'--op.help={module}.Class', '--op.init_args.p1=1'])
        self.assertIn('Expected a nested --*.help option', stderr_buffer.getvalue())
def test_parser_mode_omegaconf_in_subcommands(self):
    """omegaconf interpolation works in a config file given to a subcommand."""
    subparser = ArgumentParser()
    subparser.add_argument('--config', action=ActionConfigFile)
    for option_name in ('--source', '--target'):
        subparser.add_argument(option_name, type=str)

    parser = ArgumentParser(error_handler=None, parser_mode='omegaconf')
    parser.add_subcommands().add_subcommand('sub', subparser)

    # 'target' interpolates 'source'; it must resolve to 'hello'.
    interpolated_config = {'source': 'hello', 'target': '${source}'}
    parsed = parser.parse_args(['sub', f'--config={yaml_dump(interpolated_config)}'])
    self.assertEqual(parsed.sub.target, 'hello')
def read_config(args: Optional[List[str]] = None) -> SimpleNamespace:
    """Parse command line arguments and return the benchmark configuration.

    The single positional ``config`` argument is handled by
    ``ActionJsonSchema(schema=SCHEMA)``. If the parsed configuration has no
    ``num_warmup`` attribute, it is set to ``iterations // 2``.

    :param args: argument list to parse. Defaults to ``None`` so the function
        can be called without arguments. ``None`` is forwarded to
        ``parse_args``, which (argparse convention) then reads the process
        command line.
    :returns: benchmark configuration.
    """
    # NOTE(review): the return annotation says SimpleNamespace; the attribute
    # access below assumes an attribute-style object, not a plain dict. Confirm
    # what ActionJsonSchema yields for ``config``.
    parser = ArgumentParser()
    # NOTE(review): presumably '%s' is expanded by ActionJsonSchema into the
    # schema text in --help output; confirm against the jsonargparse docs.
    parser.add_argument("config", action=ActionJsonSchema(schema=SCHEMA), help="%s")
    config = parser.parse_args(args).config
    # Default num_warmup to half of the iteration count when not provided.
    if not hasattr(config, "num_warmup"):
        config.num_warmup = config.iterations // 2
    return config
def make_grouped_parser(arg_groups, check_required=True):
    """Build an argument parser whose arguments are divided over groups.

    A ``--hparams_file`` option handled by ``ActionConfigFile`` is always
    added (presumably so arguments can also be loaded from a config file).

    :param arg_groups: a dict of dicts ``{group_name: args}``. See ./args for
        the format of ``args``.
    :param check_required: forwarded to ``add_arg_group``. Defaults to True.
    :return: argument parser containing the args from ``arg_groups``.
    """
    grouped_parser = ArgumentParser()
    grouped_parser.add_argument('--hparams_file', action=ActionConfigFile)
    for name, group_args in arg_groups.items():
        add_arg_group(grouped_parser, name, group_args, check_required=check_required)
    return grouped_parser
def __init__(self):
    """Build the command-line parser, parse arguments and normalise the config.

    Sets ``self.parser`` and ``self.config``. Terminates the process with a
    non-zero exit status (via ``SystemExit``) when no model ``name`` is given.
    """
    import sys  # local import: the module's import block is outside this view

    self.parser = ArgumentParser(
        description=
        "A spatial & temporal-based two-stream convolutional neural network for recognising great ape behaviour."
    )
    # Allow options to also be supplied through a config file.
    self.parser.add_argument("--config", action=ActionConfigFile)

    # Register every argument family on the parser.
    self.add_general_arguments()
    self.add_hyperparameter_arguments()
    self.add_loss_arguments()
    self.add_lstm_arguments()
    self.add_dataset_arguments()
    self.add_dataloader_arguments()
    self.add_augmentation_arguments()
    self.add_path_arguments()
    self.add_frequency_arguments()

    self.config = self.parser.parse_args()

    if not self.config.name:
        # sys.exit(str) prints the message to stderr and exits with status 1.
        # The previous bare exit() reported success (status 0) on this error
        # path and depends on the `site` module having installed `exit`.
        sys.exit("Please specify model name in order to train/evaluate")

    # One data-loading worker per CPU reported by cpu_count().
    self.config.dataloader.worker_count = cpu_count()

    if self.config.mode == "test":
        # Test mode: single-item, unshuffled batches with a fixed sequence length.
        self.config.dataloader.batch_size = 1
        self.config.dataloader.shuffle = False
        self.config.dataset.sequence_length = 20

    # An empty bucket string is normalised to None.
    if self.config.bucket == "":
        self.config.bucket = None

    # Convert each configured path to a plain string (the parsed values are
    # presumably Path-like objects; TODO confirm against add_path_arguments).
    paths = self.config.paths
    for key in ("annotations", "checkpoints", "classes", "frames", "logs", "output", "splits"):
        setattr(paths, key, str(getattr(paths, key)))
def test_add_argument_type(self):
    """A restricted string type accepts exactly four digits and nothing else."""
    four_digits_type = restricted_string_type('FourDigits', '^[0-9]{4}$')
    parser = ArgumentParser(error_handler=None)
    parser.add_argument('--op', type=four_digits_type)

    self.assertEqual('1234', parser.parse_args(['--op', '1234']).op)
    for bad_value in ('123', '12345', 'abcd'):
        with self.assertRaises(ParserError):
            parser.parse_args(['--op', bad_value])
def test_default_path_unregistered_type(self):
    """A Path default with an unregistered path type dumps and shows in help."""
    parser = ArgumentParser()
    drw_type = path_type('drw', skip_check=True)
    default_path = Path('test', mode='drw', skip_check=True)
    parser.add_argument('--path', type=drw_type, default=default_path)

    self.assertEqual('path: test\n', parser.dump(parser.parse_args([])))

    help_buffer = StringIO()
    parser.print_help(help_buffer)
    self.assertIn('(type: Path_drw_skip_check, default: test)', help_buffer.getvalue())
def test_list_union(self):
    """Lists of unions (including Enum and None) parse each item independently."""
    class MyEnum(Enum):
        ab = 1

    parser = ArgumentParser(error_handler=None)
    parser.add_argument('--list1', type=List[Union[float, str, type(None)]])
    parser.add_argument('--list2', type=List[Union[int, MyEnum]])

    first = parser.parse_args(['--list1=[1.2, "ab"]']).list1
    self.assertEqual([1.2, 'ab'], first)
    second = parser.parse_args(['--list2=[3, "ab"]']).list2
    self.assertEqual([3, MyEnum.ab], second)

    # A JSON object is not a valid list value.
    with self.assertRaises(ParserError):
        parser.parse_args(['--list1={"a":1, "b":"2"}'])