def test_mapping_class_typehint(self):
        """Mapping[str, <class>] and List[int] hints parse, instantiate and validate.

        ``B.__init__`` is annotated with ``Mapping[str, A]`` and ``List[int]``.
        A nested config whose mapping value is a ``class_path`` entry must
        parse into a ``Namespace``; ``instantiate_classes`` must then build a
        ``B`` whose ``class_map`` is a dict holding an ``A`` instance. Finally,
        placing the mapping where the int list belongs must raise ParserError.
        """
        class A:
            pass

        class B:
            def __init__(
                self,
                class_map: Mapping[str, A],
                int_list: List[int],
            ):
                self.class_map = class_map
                self.int_list = int_list

        with mock_module(A, B) as module:
            # Dotted import path under which the mocked module exposes A.
            a_path = f'{module}.A'

            parser = ArgumentParser(error_handler=None)
            parser.add_class_arguments(B, 'b')

            # Keep a handle on the 'b' section so it can be corrupted later.
            b_section = {
                'class_map': {'one': {'class_path': a_path}},
                'int_list': [1],
            }
            config = {'b': b_section}

            parsed = parser.parse_object(config)
            self.assertEqual(parsed.b.class_map,
                             {'one': Namespace(class_path=a_path)})
            self.assertEqual(parsed.b.int_list, [1])

            instantiated = parser.instantiate_classes(parsed)
            built = instantiated.b
            self.assertIsInstance(built, B)
            self.assertIsInstance(built.class_map, dict)
            self.assertIsInstance(built.class_map['one'], A)

            # A mapping of class specs is not a valid List[int] value.
            b_section['int_list'] = b_section['class_map']
            with self.assertRaises(ParserError):
                parser.parse_object(config)
예제 #2
0
    def test_ActionJsonnet_parse(self):
        """ActionJsonnet.parse uses ext_vars collected by ActionJsonnetExtVars.

        The ext_vars are given once as a JSON string on the command line and
        once as a dict through parse_object; the jsonnet evaluation must see
        ``param`` and both routes must produce equal ext_vars.
        """
        parser = ArgumentParser(error_handler=None)
        parser.add_argument('--ext_vars', action=ActionJsonnetExtVars())

        from_cli = parser.parse_args(['--ext_vars', '{"param": 123}'])
        jsonnet = ActionJsonnet(schema=None)
        result = jsonnet.parse(example_2_jsonnet, ext_vars=from_cli.ext_vars)

        self.assertEqual(123, result['param'])
        records = result['records']
        self.assertEqual(9, len(records))
        second_last = records[-2]
        self.assertEqual('#8', second_last['ref'])
        self.assertEqual(15.5, second_last['val'])

        # The same ext_vars supplied as a plain dict must round-trip equally.
        from_obj = parser.parse_object({'ext_vars': {'param': 123}})
        self.assertEqual(from_cli.ext_vars, from_obj.ext_vars)
예제 #3
0
    def test_add_argument_type(self):
        """Custom ``type=`` callables validate argument values.

        Exercises a predefined type (NonNegativeFloat), a type built with
        restricted_number_type from comparison pairs, and a plain function
        accepting either the literal 'off' or a PositiveInt, used both
        single-valued and with ``nargs='+'``. Values go through parse_args
        and, for the multi-valued option, also through parse_object.
        """
        ten_to_twenty = restricted_number_type(
            'TenToTwenty', int, [('>=', 10), ('<=', 20)])

        def gt0_or_off(x):
            # 'off' passes through untouched; anything else must be a PositiveInt.
            if x == 'off':
                return x
            return PositiveInt(x)

        parser = ArgumentParser(error_handler=None)
        parser.add_argument('--le0', type=NonNegativeFloat)
        parser.add_argument('--f10t20', type=ten_to_twenty, nargs='+')
        parser.add_argument('--gt0_or_off', type=gt0_or_off)
        parser.add_argument('--multi_gt0_or_off', type=gt0_or_off, nargs='+')

        def cli(*argv):
            return parser.parse_args(list(argv))

        # --le0: despite the option name, NonNegativeFloat accepts 0 and
        # positive values and rejects negatives.
        self.assertEqual(0.0, cli('--le0', '0').le0)
        self.assertEqual(5.6, cli('--le0', '5.6').le0)
        with self.assertRaises(ParserError):
            cli('--le0', '-2.1')

        # --f10t20: ints in the closed range [10, 20]; non-ints are rejected.
        self.assertEqual([11, 14, 16],
                         cli('--f10t20', '11', '14', '16').f10t20)
        for bad in ('9', '21', '10.5'):
            with self.assertRaises(ParserError):
                cli('--f10t20', bad)

        # --gt0_or_off: a positive int or the literal 'off'.
        self.assertEqual(1, cli('--gt0_or_off', '1').gt0_or_off)
        self.assertEqual('off', cli('--gt0_or_off', 'off').gt0_or_off)
        for bad in ('0', 'on'):
            with self.assertRaises(ParserError):
                cli('--gt0_or_off', bad)

        # --multi_gt0_or_off: one invalid element rejects the whole value,
        # whether it comes from the command line or from a parsed object.
        self.assertEqual([1, 'off'],
                         cli('--multi_gt0_or_off', '1', 'off').multi_gt0_or_off)
        with self.assertRaises(ParserError):
            cli('--multi_gt0_or_off', '1', '0')
        with self.assertRaises(ParserError):
            parser.parse_object({'multi_gt0_or_off': [1, 0]})