Example #1
0
 def add_subcommand_from_function(self,
                                  subcommands,
                                  function,
                                  function_name=None):
     """Register a datamodule subcommand whose options come from ``function``.

     The subcommand parser is filled with the arguments of the class that
     ``class_from_function`` builds with
     ``return_type=self.local_datamodule_class``.  When
     ``get_kwarg_name(function)`` returns ``"data_module_kwargs"`` the
     function is used unchanged and the parameter names listed below are
     passed to ``add_class_arguments`` as ``skip``.  Otherwise the function is
     first wrapped with ``drop_kwargs`` and no ``skip`` is passed.

     Args:
         subcommands: Object exposing ``add_subcommand(name, parser)``.
         function: Callable whose signature defines the subcommand options.
         function_name: Subcommand name; defaults to ``function.__name__``.

     Side effects:
         Stores ``function`` in ``self._subcommand_builders`` under the
         subcommand name.
     """
     subcommand = ArgumentParser()
     arg_options = {"fail_untyped": False}
     if get_kwarg_name(function) == "data_module_kwargs":
         source = function
         # Parameter names excluded from the generated arguments.
         arg_options["skip"] = {
             "self",
             "train_dataset",
             "val_dataset",
             "test_dataset",
             "predict_dataset",
             "train_input",
             "val_input",
             "test_input",
             "predict_input",
             "input",
             "input_transform",
         }
     else:
         source = drop_kwargs(function)
     datamodule_function = class_from_function(
         source, return_type=self.local_datamodule_class)
     subcommand.add_class_arguments(datamodule_function, **arg_options)

     name = function_name or function.__name__
     subcommands.add_subcommand(name, subcommand)
     self._subcommand_builders[name] = function
    def test_class_type_with_default_config_files(self):
        """A class_path/init_args value in a default config file becomes the default.

        Checks ``get_defaults``, ``dump``, a plain ``parse_args`` and a parse
        with ``defaults=False``.
        """
        cal_config = {
            'class_path': 'calendar.Calendar',
            'init_args': {'firstweekday': 3},
        }
        default_path = os.path.join(self.tmpdir, 'config.yaml')
        with open(default_path, 'w') as handle:
            # JSON content is also valid YAML.
            json.dump({'data': {'cal': cal_config}}, handle)

        class MyClass:
            def __init__(self, cal: Optional[Calendar] = None, val: int = 2):
                self.cal = cal

        parser = ArgumentParser(error_handler=None,
                                default_config_files=[default_path])
        parser.add_argument('--op', default='from default')
        parser.add_class_arguments(MyClass, 'data')

        defaults = parser.get_defaults()
        self.assertEqual(default_path, str(defaults['__default_config__']))
        self.assertEqual(defaults.data.cal.as_dict(), cal_config)
        dumped = parser.dump(defaults)
        for fragment in ('class_path: calendar.Calendar\n',
                         'firstweekday: 3\n'):
            self.assertIn(fragment, dumped)

        parsed = parser.parse_args([])
        self.assertEqual(parsed.data.cal.as_dict(), cal_config)

        # With defaults=False only the explicitly given class_path is present.
        without_defaults = parser.parse_args(
            ['--data.cal.class_path=calendar.Calendar'], defaults=False)
        self.assertEqual(without_defaults.data.cal,
                         Namespace(class_path='calendar.Calendar'))
    def test_linking_deep_targets_mapping(self):
        """Link a computed value into an entry of a mapping of subclasses.

        ``C.fn`` is linked with ``apply_on="instantiate"`` to the ``d`` init
        argument of the ``A`` stored under key ``"name"`` of ``BSub.a_map``.
        """
        class D:
            pass

        class A:
            def __init__(self, d: D) -> None:
                self.d = d

        class BSuper:
            pass

        class BSub(BSuper):
            def __init__(self, a_map: Mapping[str, A]) -> None:
                self.a_map = a_map

        class C:
            def fn(self) -> D:
                return D()

        with mock_module(D, A, BSuper, BSub, C) as module:
            a_entry = {"class_path": f"{module}.A"}
            b_config = {
                "class_path": f"{module}.BSub",
                "init_args": {"a_map": {"name": a_entry}},
            }
            config_path = os.path.join(self.tmpdir, 'config.yaml')
            with open(config_path, 'w') as f:
                yaml.safe_dump({"b": b_config, "c": {}}, f)

            parser = ArgumentParser()
            parser.add_argument("--config", action=ActionConfigFile)
            parser.add_subclass_arguments(BSuper, nested_key="b", required=True)
            parser.add_class_arguments(C, nested_key="c")
            parser.link_arguments(
                "c",
                "b.init_args.a_map.name.init_args.d",
                compute_fn=C.fn,
                apply_on="instantiate",
            )

            cfg = parser.parse_args(["--config", config_path])
            # Instantiate twice from the same parsed config; both times the
            # linked value must reach the nested A instance.
            for _ in range(2):
                cfg_init = parser.instantiate_classes(cfg)
                self.assertIsInstance(cfg_init["b"].a_map["name"].d, D)
Example #4
0
    def test_parser_mode_jsonnet_subconfigs_issue_125(self):
        """Regression test for issue 125: jsonnet sub-config with a local import.

        The group's sub-config ``conf/test.jsonnet`` imports
        ``name.libsonnet`` from the same directory; the imported value must
        end up in the parsed namespace.
        """
        conf_dir = 'conf'
        # NOTE(review): created relative to the current working directory,
        # presumably a temporary directory prepared by the test fixture.
        os.mkdir(conf_dir)
        with open(os.path.join(conf_dir, 'name.libsonnet'), 'w') as f:
            f.write('"Mike"')
        config_path = os.path.join(conf_dir, 'test.jsonnet')
        jsonnet_source = (
            'local name = import "name.libsonnet"; {"name": name, "prize": 80}'
        )
        with open(config_path, 'w') as f:
            f.write(jsonnet_source)

        class Class:
            def __init__(self, name: str = 'Lucky', prize: int = 100):
                pass

        parser = ArgumentParser(parser_mode='jsonnet', error_handler=None)
        parser.add_class_arguments(Class, 'group', sub_configs=True)

        group = parser.parse_args([f'--group={config_path}']).group
        self.assertEqual(group.name, 'Mike')
        self.assertEqual(group.prize, 80)
    def test_mapping_class_typehint(self):
        """``Mapping[str, A]`` values are parsed and instantiated as classes.

        Also checks that giving a mapping for a ``List[int]`` parameter makes
        ``parse_object`` raise ``ParserError``.
        """
        class A:
            pass

        class B:
            def __init__(
                self,
                class_map: Mapping[str, A],
                int_list: List[int],
            ):
                self.class_map = class_map
                self.int_list = int_list

        with mock_module(A, B) as module:
            parser = ArgumentParser(error_handler=None)
            parser.add_class_arguments(B, 'b')

            a_path = f'{module}.A'
            config = {
                'b': {
                    'class_map': {'one': {'class_path': a_path}},
                    'int_list': [1],
                },
            }

            cfg = parser.parse_object(config)
            self.assertEqual(cfg.b.class_map,
                             {'one': Namespace(class_path=a_path)})
            self.assertEqual(cfg.b.int_list, [1])

            cfg_init = parser.instantiate_classes(cfg)
            self.assertIsInstance(cfg_init.b, B)
            self.assertIsInstance(cfg_init.b.class_map, dict)
            self.assertIsInstance(cfg_init.b.class_map['one'], A)

            # A mapping is not a valid value for the List[int] parameter.
            config['b']['int_list'] = config['b']['class_map']
            with self.assertRaises(ParserError):
                parser.parse_object(config)
Example #6
0
    def _adapt_types(val,
                     annotation,
                     subschemas,
                     reverse=False,
                     instantiate_classes=False):
        """Adapt ``val`` to (or, with ``reverse``, back from) ``annotation``.

        Args:
            val: The value to adapt.
            annotation: Type annotation describing the Python-side type.
            subschemas: Nested ``(subannotation, subvalidator, subsubschemas)``
                entries used to adapt inner values of ``Union``, sequence and
                ``Dict`` annotations.  For ``Tuple``/``Set`` annotations,
                ``subschemas[n]`` is the list of entries for position ``n``.
                ``None`` is treated as an empty list.
            reverse: If true, convert towards a serializable form (Enum member
                -> name, Path -> str, tuple/set -> list) instead of from it.
            instantiate_classes: If true, ``class_path`` dicts are
                instantiated and nested ``subvalidator`` checks are skipped.

        Returns:
            The adapted value.  Tuple/Set inputs are always copied, never
            modified in place.

        Raises:
            ParserError: If handling a ``class_path`` value fails: the import
                fails, the class is not a subclass of ``annotation``, or the
                nested parser rejects ``init_args``.
        """
        def validate_adapt(v, subschema):
            # Best-effort adaptation of ``v`` with a single subschema entry.
            # A jsonschema validation failure leaves ``v`` unchanged
            # (presumably so that non-matching Union members are skipped).
            if subschema is not None:
                subannotation, subvalidator, subsubschemas = subschema
                if reverse:
                    v = ActionJsonSchema._adapt_types(v, subannotation,
                                                      subsubschemas, reverse,
                                                      instantiate_classes)
                else:
                    try:
                        if subvalidator is not None and not instantiate_classes:
                            subvalidator.validate(v)
                        v = ActionJsonSchema._adapt_types(
                            v, subannotation, subsubschemas, reverse,
                            instantiate_classes)
                    except jsonschemaValidationError:
                        pass
            return v

        if subschemas is None:
            subschemas = []

        if _issubclass(annotation, Enum):
            if reverse and isinstance(val, annotation):
                val = val.name
            elif not reverse and val in annotation.__members__:
                val = annotation[val]

        elif _issubclass(annotation, Path):
            if reverse and isinstance(val, annotation):
                val = str(val)
            elif not reverse:
                val = annotation(val)

        elif not hasattr(annotation, '__origin__'):
            # Plain (non-generic) class: a dict with 'class_path' describes a
            # subclass to validate and optionally instantiate.
            if not reverse and \
               not _issubclass(annotation, (str, int, float)) and \
               isinstance(val, dict) and \
               'class_path' in val:
                try:
                    val_class = import_object(val['class_path'])
                    # Explicit check rather than ``assert``: assertions are
                    # stripped when Python runs with -O.
                    if not _issubclass(val_class, annotation):
                        raise ParserError('Not a subclass of ' +
                                          annotation.__name__)
                    if 'init_args' in val:
                        # Local import, presumably to avoid a circular import.
                        from jsonargparse import ArgumentParser
                        parser = ArgumentParser(error_handler=None,
                                                parse_as_dict=True)
                        parser.add_class_arguments(val_class)
                        parser.check_config(val['init_args'])
                        if instantiate_classes:
                            init_args = parser.instantiate_subclasses(
                                val['init_args'])
                            val = val_class(**init_args)  # pylint: disable=not-a-mapping
                    elif instantiate_classes:
                        val = val_class()
                except (ImportError, ModuleNotFound, AttributeError,
                        AssertionError, ParserError) as ex:
                    raise ParserError('Problem with given class_path "' +
                                      val['class_path'] + '" :: ' +
                                      str(ex)) from ex
            return val

        elif annotation.__origin__ == Union:
            for subschema in subschemas:
                val = validate_adapt(val, subschema)

        elif annotation.__origin__ in {Tuple, tuple, Set, set} and isinstance(
                val, (list, tuple, set)):
            # Always work on a list copy: tuples and sets do not support item
            # assignment, and copying avoids mutating the caller's list.
            val = list(val)
            # NOTE(review): each subschema is applied to the original element
            # ``v``, so with several entries only the last result is kept
            # (the Union branch chains instead) -- confirm this is intended.
            for n, v in enumerate(val):
                if n < len(subschemas) and subschemas[n] is not None:
                    for subschema in subschemas[n]:
                        val[n] = validate_adapt(v, subschema)
            if not reverse:
                val = tuple(val) if annotation.__origin__ in {Tuple, tuple
                                                              } else set(val)

        # Set/set are already handled (for list values) by the branch above.
        elif annotation.__origin__ in {
                List, list, Iterable, Sequence
        } and isinstance(val, list):
            for n, v in enumerate(val):
                for subschema in subschemas:
                    val[n] = validate_adapt(v, subschema)

        elif annotation.__origin__ in {Dict, dict} and isinstance(val, dict):
            # JSON object keys are strings, so Dict[int, ...] keys are cast.
            if annotation.__args__[0] == int:
                cast = str if reverse else int
                val = {cast(k): v for k, v in val.items()}
            # Presumably value types listed in ``typesmap`` need no adaptation.
            if annotation.__args__[1] not in typesmap:
                for k, v in val.items():
                    for subschema in subschemas:
                        val[k] = validate_adapt(v, subschema)

        return val