コード例 #1
0
    def test_dict_union(self):
        class MyEnum(Enum):
            ab = 1

        parser = ArgumentParser(error_handler=None)
        parser.add_argument('--dict1',
                            type=Dict[int, Optional[Union[float, MyEnum]]])
        parser.add_argument('--dict2', type=Dict[str, Union[bool, Path_fc]])
        cfg = parser.parse_args(
            ['--dict1={"2":4.5, "6":"ab"}', '--dict2={"a":true, "b":"f"}'])
        self.assertEqual({2: 4.5, 6: MyEnum.ab}, cfg['dict1'])
        self.assertEqual({'a': True, 'b': 'f'}, cfg['dict2'])
        self.assertIsInstance(cfg['dict2']['b'], Path)
        self.assertEqual({5: None},
                         parser.parse_args(['--dict1={"5":null}'])['dict1'])
        self.assertRaises(ParserError,
                          lambda: parser.parse_args(['--dict1=["a", "b"]']))
        cfg = yaml.safe_load(parser.dump(cfg))
        self.assertEqual(
            {
                'dict1': {
                    '2': 4.5,
                    '6': 'ab'
                },
                'dict2': {
                    'a': True,
                    'b': 'f'
                }
            }, cfg)
コード例 #2
0
    def test_subcommand_with_subclass_default_override_lightning_issue_10859(
            self):
        class Arch:
            def __init__(self, a: int = 1):
                pass

        class ArchB(Arch):
            def __init__(self, a: int = 2, b: int = 3):
                pass

        class ArchC(Arch):
            def __init__(self, a: int = 4, c: int = 5):
                pass

        parser = ArgumentParser(error_handler=None)
        parser_subcommands = parser.add_subcommands()
        subparser = ArgumentParser()
        subparser.add_argument('--arch', type=Arch)

        with mock_module(Arch, ArchB, ArchC) as module:
            default = {'class_path': f'{module}.ArchB'}
            value = {
                'class_path': f'{module}.ArchC',
                'init_args': {
                    'a': 10,
                    'c': 11
                }
            }

            subparser.set_defaults(arch=default)
            parser_subcommands.add_subcommand('fit', subparser)

            cfg = parser.parse_args(['fit', f'--arch={json.dumps(value)}'])
            self.assertEqual(cfg.fit.arch.as_dict(), value)
コード例 #3
0
    def test_enum(self):
        class MyEnum(Enum):
            A = 1
            B = 2
            C = 3

        parser = ArgumentParser(error_handler=None)
        parser.add_argument('--enum',
                            type=MyEnum,
                            default=MyEnum.C,
                            help='Description')

        for val in ['A', 'B', 'C']:
            self.assertEqual(MyEnum[val],
                             parser.parse_args(['--enum=' + val]).enum)
        for val in ['X', 'b', 2]:
            self.assertRaises(
                ParserError, lambda: parser.parse_args(['--enum=' + str(val)]))

        cfg = parser.parse_args(['--enum=C'], with_meta=False)
        self.assertEqual('enum: C\n', parser.dump(cfg))

        help_str = StringIO()
        parser.print_help(help_str)
        self.assertIn('Description (type: MyEnum, default: C)',
                      help_str.getvalue())
コード例 #4
0
 def test_list_path(self):
     parser = ArgumentParser()
     parser.add_argument('--paths', type=List[Path_fc])
     cfg = parser.parse_args(['--paths=["file1", "file2"]'])
     self.assertEqual(['file1', 'file2'], cfg.paths)
     self.assertIsInstance(cfg.paths[0], Path)
     self.assertIsInstance(cfg.paths[1], Path)
コード例 #5
0
 def test_list(self):
     for list_type in [Iterable, List, Sequence]:
         with self.subTest(str(list_type)):
             parser = ArgumentParser()
             parser.add_argument('--list', type=list_type[int])
             cfg = parser.parse_args(['--list=[1, 2]'])
             self.assertEqual([1, 2], cfg.list)
コード例 #6
0
    def test_ActionJsonnet(self):
        parser = ArgumentParser(default_meta=False, error_handler=None)
        parser.add_argument('--input.ext_vars', action=ActionJsonnetExtVars())
        parser.add_argument('--input.jsonnet',
                            action=ActionJsonnet(
                                ext_vars='input.ext_vars',
                                schema=json.dumps(example_schema)))

        cfg2 = parser.parse_args([
            '--input.ext_vars', '{"param": 123}', '--input.jsonnet',
            example_2_jsonnet
        ])
        self.assertEqual(123, cfg2.input.jsonnet['param'])
        self.assertEqual(9, len(cfg2.input.jsonnet['records']))
        self.assertEqual('#8', cfg2.input.jsonnet['records'][-2]['ref'])
        self.assertEqual(15.5, cfg2.input.jsonnet['records'][-2]['val'])

        cfg1 = parser.parse_args(['--input.jsonnet', example_1_jsonnet])
        self.assertEqual(cfg1.input.jsonnet['records'],
                         cfg2.input.jsonnet['records'])

        self.assertRaises(
            ParserError, lambda: parser.parse_args([
                '--input.ext_vars', '{"param": "a"}', '--input.jsonnet',
                example_2_jsonnet
            ]))
        self.assertRaises(
            ParserError,
            lambda: parser.parse_args(['--input.jsonnet', example_2_jsonnet]))

        self.assertRaises(ValueError, lambda: ActionJsonnet(ext_vars=2))
        self.assertRaises(
            ValueError,
            lambda: ActionJsonnet(schema='.' + json.dumps(example_schema)))
コード例 #7
0
    def test_class_type_with_default_config_files(self):
        config = {
            'class_path': 'calendar.Calendar',
            'init_args': {
                'firstweekday': 3
            },
        }
        config_path = os.path.join(self.tmpdir, 'config.yaml')
        with open(config_path, 'w') as f:
            json.dump({'data': {'cal': config}}, f)

        class MyClass:
            def __init__(self, cal: Optional[Calendar] = None, val: int = 2):
                self.cal = cal

        parser = ArgumentParser(error_handler=None,
                                default_config_files=[config_path])
        parser.add_argument('--op', default='from default')
        parser.add_class_arguments(MyClass, 'data')

        cfg = parser.get_defaults()
        self.assertEqual(config_path, str(cfg['__default_config__']))
        self.assertEqual(cfg.data.cal.as_dict(), config)
        dump = parser.dump(cfg)
        self.assertIn('class_path: calendar.Calendar\n', dump)
        self.assertIn('firstweekday: 3\n', dump)

        cfg = parser.parse_args([])
        self.assertEqual(cfg.data.cal.as_dict(), config)
        cfg = parser.parse_args(['--data.cal.class_path=calendar.Calendar'],
                                defaults=False)
        self.assertEqual(cfg.data.cal,
                         Namespace(class_path='calendar.Calendar'))
コード例 #8
0
 def get_parser_lv2():
     parser_lv2 = ArgumentParser(description='parser_lv2 description')
     parser_lv2.add_argument('--a1', help='lv2_a1 help')
     group_lv2 = parser_lv2.add_argument_group(
         description='group_lv2 description')
     group_lv2.add_argument('--a2', help='lv2_a2 help')
     return parser_lv2
コード例 #9
0
    def test_ActionYesNo_parse_env(self):
        parser = example_parser()
        self.assertEqual(
            True,
            parser.parse_env({
                'APP_BOOLS__DEF_FALSE': 'true'
            }).bools.def_false)
        self.assertEqual(
            True,
            parser.parse_env({
                'APP_BOOLS__DEF_FALSE': 'yes'
            }).bools.def_false)
        self.assertEqual(
            False,
            parser.parse_env({
                'APP_BOOLS__DEF_TRUE': 'false'
            }).bools.def_true)
        self.assertEqual(
            False,
            parser.parse_env({
                'APP_BOOLS__DEF_TRUE': 'no'
            }).bools.def_true)

        parser = ArgumentParser(default_env=True, env_prefix='APP')
        parser.add_argument('--op', action=ActionYesNo, default=False)
        self.assertEqual(True, parser.parse_env({'APP_OP': 'true'}).op)
コード例 #10
0
    def get_config_parser():

        parser = ArgumentParser()
        parser.add_argument("--factor",
                            type=int,
                            default=2,
                            help="Factor to multiply")
        return parser
コード例 #11
0
 def test_typehint_serialize_list(self):
     parser = ArgumentParser()
     action = parser.add_argument('--list',
                                  type=Union[PositiveInt,
                                             List[PositiveInt]])
     self.assertEqual([1, 2],
                      action.serialize([PositiveInt(1),
                                        PositiveInt(2)]))
     self.assertRaises(ValueError, lambda: action.serialize([1, -2]))
コード例 #12
0
    def test_list_enum(self):
        class MyEnum(Enum):
            ab = 0
            xy = 1

        parser = ArgumentParser(error_handler=None)
        parser.add_argument('--list', type=List[MyEnum])
        self.assertEqual([MyEnum.xy, MyEnum.ab],
                         parser.parse_args(['--list=["xy", "ab"]']).list)
コード例 #13
0
    def test_typehint_serialize_enum(self):
        class MyEnum(Enum):
            a = 1
            b = 2

        parser = ArgumentParser()
        action = parser.add_argument('--enum', type=Optional[MyEnum])
        self.assertEqual('b', action.serialize(MyEnum.b))
        self.assertRaises(ValueError, lambda: action.serialize('x'))
コード例 #14
0
 def test_nested_tuples(self):
     parser = ArgumentParser(error_handler=None)
     parser.add_argument('--tuple',
                         type=Tuple[Tuple[str, str],
                                    Tuple[Tuple[int, float], Tuple[int,
                                                                   float]]])
     cfg = parser.parse_args(
         ['--tuple=[["foo", "bar"], [[1, 2.02], [3, 3.09]]]'])
     self.assertEqual((('foo', 'bar'), ((1, 2.02), (3, 3.09))), cfg.tuple)
コード例 #15
0
    def test_class_type_subclass_given_by_name_issue_84(self):
        class LocalCalendar(Calendar):
            pass

        parser = ArgumentParser()
        parser.add_argument('--op', type=Union[Calendar, GzipFile, None])
        cfg = parser.parse_args(['--op=TextCalendar'])
        self.assertEqual(cfg.op.class_path, 'calendar.TextCalendar')

        out = StringIO()
        parser.print_help(out)
        for class_path in [
                'calendar.Calendar', 'calendar.TextCalendar', 'gzip.GzipFile'
        ]:
            self.assertIn(class_path, out.getvalue())
        self.assertNotIn('LocalCalendar', out.getvalue())

        class HTMLCalendar(Calendar):
            pass

        with mock_module(HTMLCalendar) as module:
            err = StringIO()
            with redirect_stderr(err), self.assertRaises(SystemExit):
                parser.parse_args(['--op.help=HTMLCalendar'])
            self.assertIn('Give the full class path to avoid ambiguity',
                          err.getvalue())
            self.assertIn(f'{module}.HTMLCalendar', err.getvalue())
コード例 #16
0
 def test_nargs_questionmark(self):
     parser = ArgumentParser(error_handler=None)
     parser.add_argument('p1')
     parser.add_argument('p2', nargs='?', type=OpenUnitInterval)
     self.assertIsNone(parser.parse_args(['a']).p2)
     self.assertEqual(0.5, parser.parse_args(['a', '0.5']).p2)
     self.assertRaises(ParserError, lambda: parser.parse_args(['a', 'b']))
コード例 #17
0
ファイル: main.py プロジェクト: kshah1997/pplbench-1
def read_config(args: Optional[List[str]]) -> SimpleNamespace:
    """
    Parse command line arguments and return a JSON object.
    :returns: benchmark configuration.
    """
    parser = ArgumentParser()
    parser.add_argument("config",
                        action=ActionJsonSchema(schema=SCHEMA),
                        help="%s")
    config = parser.parse_args(args).config
    return config
コード例 #18
0
 def test_ActionParser_failures(self):
     parser_lv2 = ArgumentParser()
     parser_lv2.add_argument('--op')
     parser = ArgumentParser(error_handler=None)
     parser.add_argument('--inner', action=ActionParser(parser=parser_lv2))
     self.assertRaises(
         ValueError,
         lambda: parser.add_argument('--mistake',
                                     action=ActionParser(parser=parser)))
     self.assertRaises(ParserError,
                       lambda: parser.parse_args(['--inner=1']))
コード例 #19
0
    def test_parser_mode_subparsers(self):
        subparser = ArgumentParser()
        parser = ArgumentParser()
        subcommands = parser.add_subcommands()
        subcommands.add_subcommand('sub', subparser)

        with unittest.mock.patch.dict('jsonargparse.loaders_dumpers.loaders'):
            set_loader('custom', yaml.safe_load)
            parser.parser_mode = 'custom'
            self.assertEqual('custom', parser.parser_mode)
            self.assertEqual('custom', subparser.parser_mode)
コード例 #20
0
 def test_ActionYesNo_old_bool(self):
     parser = ArgumentParser(error_handler=None)
     parser.add_argument('--val',
                         nargs=1,
                         action=ActionYesNo(no_prefix=None))
     self.assertEqual(False, parser.get_defaults().val)
     self.assertEqual(True, parser.parse_args(['--val', 'true']).val)
     self.assertEqual(True, parser.parse_args(['--val', 'yes']).val)
     self.assertEqual(False, parser.parse_args(['--val', 'false']).val)
     self.assertEqual(False, parser.parse_args(['--val', 'no']).val)
     self.assertRaises(ParserError,
                       lambda: parser.parse_args(['--val', '1']))
コード例 #21
0
 def test_ActionParser_conflict(self):
     parser_lv2 = ArgumentParser()
     parser_lv2.add_argument('--op')
     parser = ArgumentParser(error_handler=None)
     parser.add_argument('--inner.op')
     self.assertRaises(
         ValueError,
         lambda: parser.add_argument('--inner',
                                     action=ActionParser(parser_lv2)))
コード例 #22
0
    def test_class_path_override_with_default_config_files(self):
        class MyCalendar(Calendar):
            def __init__(self, *args, param: str = '0', **kwargs):
                super().__init__(*args, **kwargs)

        with mock_module(MyCalendar) as module:
            config = {
                'class_path': f'{module}.MyCalendar',
                'init_args': {
                    'firstweekday': 2,
                    'param': '1'
                },
            }
            config_path = os.path.join(self.tmpdir, 'config.yaml')
            with open(config_path, 'w') as f:
                json.dump({'cal': config}, f)

            parser = ArgumentParser(error_handler=None,
                                    default_config_files=[config_path])
            parser.add_argument('--cal', type=Optional[Calendar])

            cfg = parser.instantiate_classes(parser.get_defaults())
            self.assertIsInstance(cfg['cal'], MyCalendar)

            cfg = parser.parse_args([
                '--cal={"class_path": "calendar.Calendar", "init_args": {"firstweekday": 3}}'
            ])
            self.assertEqual(type(parser.instantiate_classes(cfg)['cal']),
                             Calendar)
コード例 #23
0
    def test_class_type_subclass_nested_help(self):
        class Class:
            def __init__(self, cal: Calendar, p1: int = 0):
                self.cal = cal

        parser = ArgumentParser()
        parser.add_argument('--op', type=Class)

        for pattern in [r'[\s=]', r'\s']:
            with self.subTest('" "' if '=' in pattern else '"="'), mock_module(
                    Class) as module:
                out = StringIO()
                args = re.split(
                    pattern,
                    f'--op.help={module}.Class --op.init_args.cal.help=TextCalendar'
                )
                with redirect_stdout(out), self.assertRaises(SystemExit):
                    parser.parse_args(args)
                self.assertIn('--op.init_args.cal.init_args.firstweekday',
                              out.getvalue())

        with self.subTest('invalid'), mock_module(Class) as module:
            err = StringIO()
            with redirect_stderr(err), self.assertRaises(SystemExit):
                parser.parse_args(
                    [f'--op.help={module}.Class', '--op.init_args.p1=1'])
            self.assertIn('Expected a nested --*.help option', err.getvalue())
コード例 #24
0
    def test_parser_mode_omegaconf_in_subcommands(self):
        subparser = ArgumentParser()
        subparser.add_argument('--config', action=ActionConfigFile)
        subparser.add_argument('--source', type=str)
        subparser.add_argument('--target', type=str)

        parser = ArgumentParser(error_handler=None, parser_mode='omegaconf')
        subcommands = parser.add_subcommands()
        subcommands.add_subcommand('sub', subparser)

        config = {
            'source': 'hello',
            'target': '${source}',
        }
        cfg = parser.parse_args(['sub', f'--config={yaml_dump(config)}'])
        self.assertEqual(cfg.sub.target, 'hello')
コード例 #25
0
ファイル: main.py プロジェクト: feynmanliang/pplbench-1
def read_config(args: Optional[List[str]]) -> SimpleNamespace:
    """
    Parse command line arguments and return a JSON object.
    :returns: benchmark configuration.
    """
    parser = ArgumentParser()
    parser.add_argument("config",
                        action=ActionJsonSchema(schema=SCHEMA),
                        help="%s")
    config = parser.parse_args(args).config

    # default num_warmup to half of num_sample
    if not hasattr(config, "num_warmup"):
        config.num_warmup = config.iterations // 2

    return config
コード例 #26
0
def make_grouped_parser(arg_groups, check_required=True):
    """
    Builds an argument parser, with args divided over groups.

    :param arg_groups: a dict of dicts {group_name: args}. See ./args for format.
    :param check_required: Check for required arguments, defaults to True
    :return: argument parser containing the args from arg_groups.
    """
    #parser = ArgumentParser(formatter_class='default_argparse')
    parser = ArgumentParser()
    parser.add_argument('--hparams_file', action=ActionConfigFile)
    for group_name, arg_group in arg_groups.items():
        add_arg_group(parser,
                      group_name,
                      arg_group,
                      check_required=check_required)
    return parser
コード例 #27
0
    def __init__(self):

        # Initialise parser
        self.parser = ArgumentParser(
            description=
            "A spatial & temporal-based two-stream convolutional neural network for recognising great ape behaviour."
        )

        # Take in config
        self.parser.add_argument("--config", action=ActionConfigFile)

        # Add all arguments to parser
        self.add_general_arguments()
        self.add_hyperparameter_arguments()
        self.add_loss_arguments()
        self.add_lstm_arguments()
        self.add_dataset_arguments()
        self.add_dataloader_arguments()
        self.add_augmentation_arguments()
        self.add_path_arguments()
        self.add_frequency_arguments()

        # Create config
        self.config = self.parser.parse_args()

        if not self.config.name:
            print("Please specify model name in order to train/evaluate")
            exit()

        self.config.dataloader.worker_count = cpu_count()

        if self.config.mode == "test":
            self.config.dataloader.batch_size = 1
            self.config.dataloader.shuffle = False
            self.config.dataset.sequence_length = 20

        if self.config.bucket == "":
            self.config.bucket = None

        self.config.paths.annotations = str(self.config.paths.annotations)
        self.config.paths.checkpoints = str(self.config.paths.checkpoints)
        self.config.paths.classes = str(self.config.paths.classes)
        self.config.paths.frames = str(self.config.paths.frames)
        self.config.paths.logs = str(self.config.paths.logs)
        self.config.paths.output = str(self.config.paths.output)
        self.config.paths.splits = str(self.config.paths.splits)
コード例 #28
0
 def test_add_argument_type(self):
     FourDigits = restricted_string_type('FourDigits', '^[0-9]{4}$')
     parser = ArgumentParser(error_handler=None)
     parser.add_argument('--op', type=FourDigits)
     self.assertEqual('1234', parser.parse_args(['--op', '1234']).op)
     self.assertRaises(ParserError,
                       lambda: parser.parse_args(['--op', '123']))
     self.assertRaises(ParserError,
                       lambda: parser.parse_args(['--op', '12345']))
     self.assertRaises(ParserError,
                       lambda: parser.parse_args(['--op', 'abcd']))
コード例 #29
0
 def test_default_path_unregistered_type(self):
     parser = ArgumentParser()
     parser.add_argument('--path',
                         type=path_type('drw', skip_check=True),
                         default=Path('test', mode='drw', skip_check=True))
     cfg = parser.parse_args([])
     self.assertEqual('path: test\n', parser.dump(cfg))
     out = StringIO()
     parser.print_help(out)
     self.assertIn('(type: Path_drw_skip_check, default: test)',
                   out.getvalue())
コード例 #30
0
    def test_list_union(self):
        class MyEnum(Enum):
            ab = 1

        parser = ArgumentParser(error_handler=None)
        parser.add_argument('--list1',
                            type=List[Union[float, str, type(None)]])
        parser.add_argument('--list2', type=List[Union[int, MyEnum]])
        self.assertEqual([1.2, 'ab'],
                         parser.parse_args(['--list1=[1.2, "ab"]']).list1)
        self.assertEqual([3, MyEnum.ab],
                         parser.parse_args(['--list2=[3, "ab"]']).list2)
        self.assertRaises(
            ParserError,
            lambda: parser.parse_args(['--list1={"a":1, "b":"2"}']))