def add_subcommand_from_function(self, subcommands, function, function_name=None):
    """Register ``function`` as a subcommand whose arguments build a datamodule.

    The function is wrapped by ``class_from_function`` (with
    ``return_type=self.local_datamodule_class``) and its arguments are added to
    a fresh ``ArgumentParser``. That parser is then added to ``subcommands``.

    Args:
        subcommands: Object exposing ``add_subcommand(name, parser)``.
        function: The builder callable to expose as a subcommand.
        function_name: Optional subcommand name; defaults to
            ``function.__name__``.
    """
    subcommand = ArgumentParser()

    if get_kwarg_name(function) == "data_module_kwargs":
        # NOTE(review): get_kwarg_name presumably returns the name of the
        # function's ``**kwargs`` parameter -- confirm. In this case the
        # function is wrapped as-is, and the dataset/input/transform
        # parameters below are hidden from the CLI.
        wrapped = function
        extra_options = {
            "skip": {
                "self",
                "train_dataset",
                "val_dataset",
                "test_dataset",
                "predict_dataset",
                "train_input",
                "val_input",
                "test_input",
                "predict_input",
                "input",
                "input_transform",
            },
        }
    else:
        # NOTE(review): drop_kwargs presumably strips the ``**kwargs``
        # parameter before wrapping -- confirm.
        wrapped = drop_kwargs(function)
        extra_options = {}

    datamodule_function = class_from_function(wrapped, return_type=self.local_datamodule_class)
    subcommand.add_class_arguments(datamodule_function, fail_untyped=False, **extra_options)

    name = function_name or function.__name__
    subcommands.add_subcommand(name, subcommand)
    # Keep the original (unwrapped) builder, keyed by subcommand name.
    self._subcommand_builders[name] = function
def test_class_type_with_default_config_files(self):
    """A subclass value from a default config file shows up in defaults, dumps and parses."""
    cal_config = {
        'class_path': 'calendar.Calendar',
        'init_args': {'firstweekday': 3},
    }
    default_path = os.path.join(self.tmpdir, 'config.yaml')
    # JSON is valid YAML, so json.dump is enough to write the default config.
    with open(default_path, 'w') as stream:
        json.dump({'data': {'cal': cal_config}}, stream)

    class MyClass:
        def __init__(self, cal: Optional[Calendar] = None, val: int = 2):
            self.cal = cal

    parser = ArgumentParser(error_handler=None, default_config_files=[default_path])
    parser.add_argument('--op', default='from default')
    parser.add_class_arguments(MyClass, 'data')

    defaults = parser.get_defaults()
    self.assertEqual(default_path, str(defaults['__default_config__']))
    self.assertEqual(defaults.data.cal.as_dict(), cal_config)

    dumped = parser.dump(defaults)
    for fragment in ('class_path: calendar.Calendar\n', 'firstweekday: 3\n'):
        self.assertIn(fragment, dumped)

    self.assertEqual(parser.parse_args([]).data.cal.as_dict(), cal_config)

    # With defaults disabled only the explicitly given class_path remains.
    explicit = parser.parse_args(['--data.cal.class_path=calendar.Calendar'], defaults=False)
    self.assertEqual(explicit.data.cal, Namespace(class_path='calendar.Calendar'))
def test_linking_deep_targets_mapping(self):
    """Link an instantiated ``C`` into a target nested inside a mapping of subclasses."""

    class D:
        pass

    class A:
        def __init__(self, d: D) -> None:
            self.d = d

    class BSuper:
        pass

    class BSub(BSuper):
        def __init__(self, a_map: Mapping[str, A]) -> None:
            self.a_map = a_map

    class C:
        def fn(self) -> D:
            return D()

    with mock_module(D, A, BSuper, BSub, C) as module:
        raw_config = {
            "b": {
                "class_path": f"{module}.BSub",
                "init_args": {
                    "a_map": {
                        "name": {"class_path": f"{module}.A"},
                    },
                },
            },
            "c": {},
        }
        config_file = os.path.join(self.tmpdir, 'config.yaml')
        with open(config_file, 'w') as stream:
            yaml.safe_dump(raw_config, stream)

        parser = ArgumentParser()
        parser.add_argument("--config", action=ActionConfigFile)
        parser.add_subclass_arguments(BSuper, nested_key="b", required=True)
        parser.add_class_arguments(C, nested_key="c")
        parser.link_arguments(
            "c",
            "b.init_args.a_map.name.init_args.d",
            compute_fn=C.fn,
            apply_on="instantiate",
        )

        parsed = parser.parse_args(["--config", config_file])
        # Instantiate twice from the same namespace; presumably this guards
        # against the first instantiation mutating ``parsed``.
        for _ in range(2):
            instantiated = parser.instantiate_classes(parsed)
            self.assertIsInstance(instantiated["b"].a_map["name"].d, D)
def test_parser_mode_jsonnet_subconfigs_issue_125(self):
    """A jsonnet sub-config that imports a relative ``.libsonnet`` file resolves correctly."""
    # Relative paths: presumably the test fixture runs inside a temporary cwd.
    os.mkdir('conf')
    library_path = os.path.join('conf', 'name.libsonnet')
    with open(library_path, 'w') as out:
        out.write('"Mike"')

    main_path = os.path.join('conf', 'test.jsonnet')
    with open(main_path, 'w') as out:
        out.write('local name = import "name.libsonnet"; {"name": name, "prize": 80}')

    class Class:
        def __init__(self, name: str = 'Lucky', prize: int = 100):
            pass

    parser = ArgumentParser(parser_mode='jsonnet', error_handler=None)
    parser.add_class_arguments(Class, 'group', sub_configs=True)

    group = parser.parse_args([f'--group={main_path}']).group
    self.assertEqual(group.name, 'Mike')
    self.assertEqual(group.prize, 80)
def test_mapping_class_typehint(self):
    """``Mapping[str, A]`` accepts class_path entries; such a mapping is rejected for ``List[int]``."""

    class A:
        pass

    class B:
        def __init__(self, class_map: Mapping[str, A], int_list: List[int]):
            self.class_map = class_map
            self.int_list = int_list

    with mock_module(A, B) as module:
        a_path = f'{module}.A'
        parser = ArgumentParser(error_handler=None)
        parser.add_class_arguments(B, 'b')
        config = {
            'b': {
                'class_map': {'one': {'class_path': a_path}},
                'int_list': [1],
            },
        }

        cfg = parser.parse_object(config)
        self.assertEqual(cfg.b.class_map, {'one': Namespace(class_path=a_path)})
        self.assertEqual(cfg.b.int_list, [1])

        cfg_init = parser.instantiate_classes(cfg)
        self.assertIsInstance(cfg_init.b, B)
        self.assertIsInstance(cfg_init.b.class_map, dict)
        self.assertIsInstance(cfg_init.b.class_map['one'], A)

        # A mapping of class specs is not a valid List[int].
        config['b']['int_list'] = config['b']['class_map']
        with self.assertRaises(ParserError):
            parser.parse_object(config)
def _adapt_types(val, annotation, subschemas, reverse=False, instantiate_classes=False):
    """Convert ``val`` between its serializable form and the form implied by ``annotation``.

    Forward direction (``reverse=False``):

    - enum member names become enum members;
    - values for ``Path`` annotations are passed through the ``Path`` type;
    - dicts with a ``class_path`` key are checked against ``annotation``, and
      are instantiated when ``instantiate_classes`` is true;
    - tuple, set, list and dict contents are adapted element-wise.

    With ``reverse=True``:

    - enum members become names and paths become strings;
    - tuples and sets become lists;
    - int dict keys become strings.

    Args:
        val: The value to adapt.
        annotation: The type hint describing ``val``.
        subschemas: Nested ``(annotation, validator, subschemas)`` triples used
            to adapt contained values, or ``None``. For tuple/set annotations it
            is indexed by position, and each entry is a sequence of such triples
            (or ``None``).
        reverse: Convert towards the serializable form instead of away from it.
        instantiate_classes: For ``class_path`` dicts, instantiate the class
            instead of only checking ``init_args``. Also skips nested
            validator checks.

    Returns:
        The adapted value. Lists, and dicts whose keys are not cast, are
        updated in place and returned.

    Raises:
        ParserError: If a ``class_path`` fails to import or resolve, is not a
            subclass of ``annotation``, or its ``init_args`` fail the check.
    """

    def validate_adapt(v, subschema):
        # Adapt one nested value. In the forward direction, a value that fails
        # the subschema's validator is returned unchanged (deliberate
        # best-effort), so another candidate subschema can still apply.
        if subschema is None:
            return v
        subannotation, subvalidator, subsubschemas = subschema
        if reverse:
            return ActionJsonSchema._adapt_types(v, subannotation, subsubschemas, reverse, instantiate_classes)
        try:
            if subvalidator is not None and not instantiate_classes:
                subvalidator.validate(v)
            v = ActionJsonSchema._adapt_types(v, subannotation, subsubschemas, reverse, instantiate_classes)
        except jsonschemaValidationError:
            pass
        return v

    if subschemas is None:
        subschemas = []

    if _issubclass(annotation, Enum):
        if reverse and isinstance(val, annotation):
            val = val.name
        elif not reverse and val in annotation.__members__:
            val = annotation[val]

    elif _issubclass(annotation, Path):
        if reverse and isinstance(val, annotation):
            val = str(val)
        elif not reverse:
            val = annotation(val)

    elif not hasattr(annotation, '__origin__'):
        # Plain (non-generic) class annotation: a {'class_path': ...} dict
        # selects a subclass of it.
        if not reverse and \
           not _issubclass(annotation, (str, int, float)) and \
           isinstance(val, dict) and \
           'class_path' in val:
            try:
                val_class = import_object(val['class_path'])
                # Explicit check rather than ``assert`` so it is not stripped
                # when Python runs with -O.
                if not _issubclass(val_class, annotation):
                    raise ParserError('Not a subclass of ' + annotation.__name__)
                if 'init_args' in val:
                    from jsonargparse import ArgumentParser
                    parser = ArgumentParser(error_handler=None, parse_as_dict=True)
                    parser.add_class_arguments(val_class)
                    parser.check_config(val['init_args'])
                    if instantiate_classes:
                        init_args = parser.instantiate_subclasses(val['init_args'])
                        val = val_class(**init_args)  # pylint: disable=not-a-mapping
                elif instantiate_classes:
                    val = val_class()
            except (ImportError, ModuleNotFound, AttributeError, AssertionError, ParserError) as ex:
                raise ParserError('Problem with given class_path "' + val['class_path'] + '" :: ' + str(ex)) from ex

    elif annotation.__origin__ == Union:
        for subschema in subschemas:
            val = validate_adapt(val, subschema)

    elif annotation.__origin__ in {Tuple, tuple, Set, set} and isinstance(val, (list, tuple, set)):
        # Always work on a list copy: tuples and sets do not support item
        # assignment, so adapting them in place would raise TypeError.
        val = list(val)
        for n, options in enumerate(subschemas[:len(val)]):
            if options is not None:
                # Chain the candidates so an earlier adaptation is not
                # overwritten by a later one computed from the raw element.
                for subschema in options:
                    val[n] = validate_adapt(val[n], subschema)
        if not reverse:
            val = tuple(val) if annotation.__origin__ in {Tuple, tuple} else set(val)

    # Note: a list value for a Set annotation is already handled by the
    # Tuple/Set branch above.
    elif annotation.__origin__ in {List, list, Set, set, Iterable, Sequence} and isinstance(val, list):
        for n, item in enumerate(val):
            for subschema in subschemas:
                item = validate_adapt(item, subschema)
            val[n] = item

    elif annotation.__origin__ in {Dict, dict} and isinstance(val, dict):
        if annotation.__args__[0] == int:
            # Dict keys are strings when serialized; int keys round-trip via str.
            cast = str if reverse else int
            val = {cast(k): v for k, v in val.items()}
        if annotation.__args__[1] not in typesmap:
            # Only existing keys are reassigned, so the dict size is unchanged
            # while iterating.
            for k, item in val.items():
                for subschema in subschemas:
                    item = validate_adapt(item, subschema)
                val[k] = item

    return val