def test_enum(self):
        class MyEnum(Enum):
            A = 1
            B = 2
            C = 3

        parser = ArgumentParser(error_handler=None)
        parser.add_argument('--enum',
                            type=MyEnum,
                            default=MyEnum.C,
                            help='Description')

        for val in ['A', 'B', 'C']:
            self.assertEqual(MyEnum[val],
                             parser.parse_args(['--enum=' + val]).enum)
        for val in ['X', 'b', 2]:
            self.assertRaises(
                ParserError, lambda: parser.parse_args(['--enum=' + str(val)]))

        cfg = parser.parse_args(['--enum=C'], with_meta=False)
        self.assertEqual('enum: C\n', parser.dump(cfg))

        help_str = StringIO()
        parser.print_help(help_str)
        self.assertIn('Description (type: MyEnum, default: C)',
                      help_str.getvalue())
    def test_class_type_subclass_given_by_name_issue_84(self):
        class LocalCalendar(Calendar):
            pass

        parser = ArgumentParser()
        parser.add_argument('--op', type=Union[Calendar, GzipFile, None])
        cfg = parser.parse_args(['--op=TextCalendar'])
        self.assertEqual(cfg.op.class_path, 'calendar.TextCalendar')

        out = StringIO()
        parser.print_help(out)
        for class_path in [
                'calendar.Calendar', 'calendar.TextCalendar', 'gzip.GzipFile'
        ]:
            self.assertIn(class_path, out.getvalue())
        self.assertNotIn('LocalCalendar', out.getvalue())

        class HTMLCalendar(Calendar):
            pass

        with mock_module(HTMLCalendar) as module:
            err = StringIO()
            with redirect_stderr(err), self.assertRaises(SystemExit):
                parser.parse_args(['--op.help=HTMLCalendar'])
            self.assertIn('Give the full class path to avoid ambiguity',
                          err.getvalue())
            self.assertIn(f'{module}.HTMLCalendar', err.getvalue())
 def test_nargs_questionmark(self):
     parser = ArgumentParser(error_handler=None)
     parser.add_argument('p1')
     parser.add_argument('p2', nargs='?', type=OpenUnitInterval)
     self.assertIsNone(parser.parse_args(['a']).p2)
     self.assertEqual(0.5, parser.parse_args(['a', '0.5']).p2)
     self.assertRaises(ParserError, lambda: parser.parse_args(['a', 'b']))
    def test_dict_union(self):
        class MyEnum(Enum):
            ab = 1

        parser = ArgumentParser(error_handler=None)
        parser.add_argument('--dict1',
                            type=Dict[int, Optional[Union[float, MyEnum]]])
        parser.add_argument('--dict2', type=Dict[str, Union[bool, Path_fc]])
        cfg = parser.parse_args(
            ['--dict1={"2":4.5, "6":"ab"}', '--dict2={"a":true, "b":"f"}'])
        self.assertEqual({2: 4.5, 6: MyEnum.ab}, cfg['dict1'])
        self.assertEqual({'a': True, 'b': 'f'}, cfg['dict2'])
        self.assertIsInstance(cfg['dict2']['b'], Path)
        self.assertEqual({5: None},
                         parser.parse_args(['--dict1={"5":null}'])['dict1'])
        self.assertRaises(ParserError,
                          lambda: parser.parse_args(['--dict1=["a", "b"]']))
        cfg = yaml.safe_load(parser.dump(cfg))
        self.assertEqual(
            {
                'dict1': {
                    '2': 4.5,
                    '6': 'ab'
                },
                'dict2': {
                    'a': True,
                    'b': 'f'
                }
            }, cfg)
    def test_class_type_with_default_config_files(self):
        config = {
            'class_path': 'calendar.Calendar',
            'init_args': {
                'firstweekday': 3
            },
        }
        config_path = os.path.join(self.tmpdir, 'config.yaml')
        with open(config_path, 'w') as f:
            json.dump({'data': {'cal': config}}, f)

        class MyClass:
            def __init__(self, cal: Optional[Calendar] = None, val: int = 2):
                self.cal = cal

        parser = ArgumentParser(error_handler=None,
                                default_config_files=[config_path])
        parser.add_argument('--op', default='from default')
        parser.add_class_arguments(MyClass, 'data')

        cfg = parser.get_defaults()
        self.assertEqual(config_path, str(cfg['__default_config__']))
        self.assertEqual(cfg.data.cal.as_dict(), config)
        dump = parser.dump(cfg)
        self.assertIn('class_path: calendar.Calendar\n', dump)
        self.assertIn('firstweekday: 3\n', dump)

        cfg = parser.parse_args([])
        self.assertEqual(cfg.data.cal.as_dict(), config)
        cfg = parser.parse_args(['--data.cal.class_path=calendar.Calendar'],
                                defaults=False)
        self.assertEqual(cfg.data.cal,
                         Namespace(class_path='calendar.Calendar'))
Example #6
0
    def test_ActionJsonnet_save(self):
        parser = ArgumentParser(error_handler=None)
        parser.add_argument('--ext_vars', action=ActionJsonnetExtVars())
        parser.add_argument('--jsonnet',
                            action=ActionJsonnet(ext_vars='ext_vars'))
        parser.add_argument('--cfg', action=ActionConfigFile)

        jsonnet_file = os.path.join(self.tmpdir, 'example.jsonnet')
        with open(jsonnet_file, 'w') as output_file:
            output_file.write(example_2_jsonnet)
        outdir = os.path.join(self.tmpdir, 'output')
        outyaml = os.path.join(outdir, 'main.yaml')
        outjsonnet = os.path.join(outdir, 'example.jsonnet')
        os.mkdir(outdir)

        cfg = parser.parse_args(
            ['--ext_vars', '{"param": 123}', '--jsonnet', jsonnet_file])
        self.assertEqual(str(cfg.jsonnet['__path__']), jsonnet_file)

        parser.save(cfg, outyaml)
        cfg2 = parser.parse_args(['--cfg', outyaml])
        cfg2.cfg = None
        self.assertTrue(os.path.isfile(outyaml))
        self.assertTrue(os.path.isfile(outjsonnet))
        self.assertEqual(strip_meta(cfg), strip_meta(cfg2))

        os.unlink(outyaml)
        os.unlink(outjsonnet)
        parser.save(strip_meta(cfg), outyaml)
        cfg3 = parser.parse_args(['--cfg', outyaml])
        cfg3.cfg = None
        self.assertTrue(os.path.isfile(outyaml))
        self.assertTrue(not os.path.isfile(outjsonnet))
        self.assertEqual(strip_meta(cfg), strip_meta(cfg3))
    def test_class_type_subclass_nested_help(self):
        class Class:
            def __init__(self, cal: Calendar, p1: int = 0):
                self.cal = cal

        parser = ArgumentParser()
        parser.add_argument('--op', type=Class)

        for pattern in [r'[\s=]', r'\s']:
            with self.subTest('" "' if '=' in pattern else '"="'), mock_module(
                    Class) as module:
                out = StringIO()
                args = re.split(
                    pattern,
                    f'--op.help={module}.Class --op.init_args.cal.help=TextCalendar'
                )
                with redirect_stdout(out), self.assertRaises(SystemExit):
                    parser.parse_args(args)
                self.assertIn('--op.init_args.cal.init_args.firstweekday',
                              out.getvalue())

        with self.subTest('invalid'), mock_module(Class) as module:
            err = StringIO()
            with redirect_stderr(err), self.assertRaises(SystemExit):
                parser.parse_args(
                    [f'--op.help={module}.Class', '--op.init_args.p1=1'])
            self.assertIn('Expected a nested --*.help option', err.getvalue())
Example #8
0
    def test_ActionJsonnet(self):
        parser = ArgumentParser(default_meta=False, error_handler=None)
        parser.add_argument('--input.ext_vars', action=ActionJsonnetExtVars())
        parser.add_argument('--input.jsonnet',
                            action=ActionJsonnet(
                                ext_vars='input.ext_vars',
                                schema=json.dumps(example_schema)))

        cfg2 = parser.parse_args([
            '--input.ext_vars', '{"param": 123}', '--input.jsonnet',
            example_2_jsonnet
        ])
        self.assertEqual(123, cfg2.input.jsonnet['param'])
        self.assertEqual(9, len(cfg2.input.jsonnet['records']))
        self.assertEqual('#8', cfg2.input.jsonnet['records'][-2]['ref'])
        self.assertEqual(15.5, cfg2.input.jsonnet['records'][-2]['val'])

        cfg1 = parser.parse_args(['--input.jsonnet', example_1_jsonnet])
        self.assertEqual(cfg1.input.jsonnet['records'],
                         cfg2.input.jsonnet['records'])

        self.assertRaises(
            ParserError, lambda: parser.parse_args([
                '--input.ext_vars', '{"param": "a"}', '--input.jsonnet',
                example_2_jsonnet
            ]))
        self.assertRaises(
            ParserError,
            lambda: parser.parse_args(['--input.jsonnet', example_2_jsonnet]))

        self.assertRaises(ValueError, lambda: ActionJsonnet(ext_vars=2))
        self.assertRaises(
            ValueError,
            lambda: ActionJsonnet(schema='.' + json.dumps(example_schema)))
 def test_ActionParser_required(self):
     p1 = ArgumentParser()
     p1.add_argument('--op1', required=True)
     p2 = ArgumentParser(error_handler=None)
     p2.add_argument('--op2', action=ActionParser(parser=p1))
     p2.parse_args(['--op2.op1=1'])
     self.assertRaises(ParserError, lambda: p2.parse_args([]))
 def test_nested_mapping_without_args(self):
     parser = ArgumentParser()
     parser.add_argument('--map', type=Mapping[str, Union[int, Mapping]])
     self.assertEqual(parser.parse_args(['--map={"a": 1}']).map, {"a": 1})
     self.assertEqual(
         parser.parse_args(['--map={"b": {"c": 2}}']).map, {"b": {
             "c": 2
         }})
 def _test_typehint_parameterized_types(self, type):
     parser = ArgumentParser(error_handler=None)
     ActionTypeHint.is_supported_typehint(type, full=True)
     parser.add_argument('--cal', type=type[Calendar])
     cfg = parser.parse_args(['--cal=calendar.Calendar'])
     self.assertEqual(cfg.cal, Calendar)
     self.assertEqual(parser.dump(cfg), 'cal: calendar.Calendar\n')
     self.assertRaises(ParserError,
                       lambda: parser.parse_args(['--cal=uuid.UUID']))
 def test_no_str_strip(self):
     parser = ArgumentParser(error_handler=None)
     parser.add_argument('--op', type=Optional[str])
     parser.add_argument('--cfg', action=ActionConfigFile)
     self.assertEqual('  ', parser.parse_args(['--op', '  ']).op)
     self.assertEqual('', parser.parse_args(['--op', '']).op)
     self.assertEqual(' abc ', parser.parse_args(['--op= abc ']).op)
     self.assertEqual(' ', parser.parse_args(['--cfg={"op":" "}']).op)
     self.assertIsNone(parser.parse_args(['--op=null']).op)
Example #13
0
    def test_set_loader_safe_load_invalid_scientific_notation(self):
        parser = ArgumentParser(error_handler=None)
        parser.add_argument('--num', type=float)

        with unittest.mock.patch.dict('jsonargparse.loaders_dumpers.loaders'):
            set_loader('yaml', yaml.safe_load)
            self.assertRaises(ParserError,
                              lambda: parser.parse_args(['--num=1e-3']))

        self.assertEqual(1e-3, parser.parse_args(['--num=1e-3']).num)
 def test_optional_path(self):
     pathlib.Path('file_fr').touch()
     parser = ArgumentParser(error_handler=None)
     parser.add_argument('--path', type=Optional[Path_fr])
     self.assertIsNone(parser.parse_args(['--path=null']).path)
     cfg = parser.parse_args(['--path=file_fr'])
     self.assertEqual('file_fr', cfg.path)
     self.assertIsInstance(cfg.path, Path)
     self.assertRaises(ParserError,
                       lambda: parser.parse_args(['--path=not_exist']))
    def test_ActionYesNo(self):
        parser = example_parser()
        defaults = parser.get_defaults()
        self.assertEqual(False, defaults.bools.def_false)
        self.assertEqual(True, defaults.bools.def_true)
        self.assertEqual(
            True,
            parser.parse_args(['--bools.def_false']).bools.def_false)
        self.assertEqual(
            False,
            parser.parse_args(['--no_bools.def_false']).bools.def_false)
        self.assertEqual(
            True,
            parser.parse_args(['--bools.def_true']).bools.def_true)
        self.assertEqual(
            False,
            parser.parse_args(['--no_bools.def_true']).bools.def_true)
        self.assertEqual(
            True,
            parser.parse_args(['--bools.def_false=true']).bools.def_false)
        self.assertEqual(
            False,
            parser.parse_args(['--bools.def_false=false']).bools.def_false)
        self.assertEqual(
            True,
            parser.parse_args(['--bools.def_false=yes']).bools.def_false)
        self.assertEqual(
            False,
            parser.parse_args(['--bools.def_false=no']).bools.def_false)
        self.assertEqual(
            True,
            parser.parse_args(['--no_bools.def_true=no']).bools.def_true)
        self.assertRaises(ParserError,
                          lambda: parser.parse_args(['--bools.def_true nope']))

        parser = ArgumentParser()
        parser.add_argument('--val', action=ActionYesNo)
        self.assertEqual(True, parser.parse_args(['--val']).val)
        self.assertEqual(False, parser.parse_args(['--no_val']).val)
        parser = ArgumentParser()
        parser.add_argument('--with-val',
                            action=ActionYesNo(yes_prefix='with-',
                                               no_prefix='without-'))
        self.assertEqual(True, parser.parse_args(['--with-val']).with_val)
        self.assertEqual(False, parser.parse_args(['--without-val']).with_val)
        parser = ArgumentParser()
        self.assertRaises(
            ValueError,
            lambda: parser.add_argument('--val',
                                        action=ActionYesNo(yes_prefix='yes_')))
        self.assertRaises(
            ValueError, lambda: parser.add_argument('pos', action=ActionYesNo))
        self.assertRaises(
            ValueError, lambda: parser.add_argument(
                '--val', nargs='?', action=ActionYesNo(no_prefix=None)))
 def test_dict(self):
     parser = ArgumentParser(error_handler=None)
     parser.add_argument('--dict', type=dict)
     self.assertEqual({}, parser.parse_args(['--dict={}'])['dict'])
     self.assertEqual({
         'a': 1,
         'b': '2'
     },
                      parser.parse_args(['--dict={"a":1, "b":"2"}'
                                         ])['dict'])
     self.assertRaises(ParserError, lambda: parser.parse_args(['--dict=1']))
 def test_add_argument_type(self):
     FourDigits = restricted_string_type('FourDigits', '^[0-9]{4}$')
     parser = ArgumentParser(error_handler=None)
     parser.add_argument('--op', type=FourDigits)
     self.assertEqual('1234', parser.parse_args(['--op', '1234']).op)
     self.assertRaises(ParserError,
                       lambda: parser.parse_args(['--op', '123']))
     self.assertRaises(ParserError,
                       lambda: parser.parse_args(['--op', '12345']))
     self.assertRaises(ParserError,
                       lambda: parser.parse_args(['--op', 'abcd']))
 def test_ActionYesNo_old_bool(self):
     parser = ArgumentParser(error_handler=None)
     parser.add_argument('--val',
                         nargs=1,
                         action=ActionYesNo(no_prefix=None))
     self.assertEqual(False, parser.get_defaults().val)
     self.assertEqual(True, parser.parse_args(['--val', 'true']).val)
     self.assertEqual(True, parser.parse_args(['--val', 'yes']).val)
     self.assertEqual(False, parser.parse_args(['--val', 'false']).val)
     self.assertEqual(False, parser.parse_args(['--val', 'no']).val)
     self.assertRaises(ParserError,
                       lambda: parser.parse_args(['--val', '1']))
    def test_bool(self):
        parser = ArgumentParser(prog='app',
                                default_env=True,
                                error_handler=None)
        parser.add_argument('--val', type=bool)
        self.assertEqual(None, parser.get_defaults().val)
        self.assertEqual(True, parser.parse_args(['--val', 'true']).val)
        self.assertEqual(True, parser.parse_args(['--val', 'TRUE']).val)
        self.assertEqual(False, parser.parse_args(['--val', 'false']).val)
        self.assertEqual(False, parser.parse_args(['--val', 'FALSE']).val)
        self.assertRaises(ParserError,
                          lambda: parser.parse_args(['--val', '1']))

        os.environ['APP_VAL'] = 'true'
        self.assertEqual(True, parser.parse_args([]).val)
        os.environ['APP_VAL'] = 'True'
        self.assertEqual(True, parser.parse_args([]).val)
        os.environ['APP_VAL'] = 'false'
        self.assertEqual(False, parser.parse_args([]).val)
        os.environ['APP_VAL'] = 'False'
        self.assertEqual(False, parser.parse_args([]).val)
        os.environ['APP_VAL'] = '2'
        self.assertRaises(ParserError,
                          lambda: parser.parse_args(['--val', 'a']))
        del os.environ['APP_VAL']
    def test_list_union(self):
        class MyEnum(Enum):
            ab = 1

        parser = ArgumentParser(error_handler=None)
        parser.add_argument('--list1',
                            type=List[Union[float, str, type(None)]])
        parser.add_argument('--list2', type=List[Union[int, MyEnum]])
        self.assertEqual([1.2, 'ab'],
                         parser.parse_args(['--list1=[1.2, "ab"]']).list1)
        self.assertEqual([3, MyEnum.ab],
                         parser.parse_args(['--list2=[3, "ab"]']).list2)
        self.assertRaises(
            ParserError,
            lambda: parser.parse_args(['--list1={"a":1, "b":"2"}']))
    def test_ActionJsonSchema(self):
        parser = ArgumentParser(prog='app', default_meta=False, error_handler=None)
        parser.add_argument('--op1',
            action=ActionJsonSchema(schema=schema1))
        parser.add_argument('--op2',
            action=ActionJsonSchema(schema=schema2))
        parser.add_argument('--op3',
            action=ActionJsonSchema(schema=schema3))
        parser.add_argument('--cfg',
            action=ActionConfigFile)

        op1_val = [1, 2, 3, 4]
        op2_val = {'k1': 'one', 'k2': 2, 'k3': 3.3}

        self.assertEqual(op1_val, parser.parse_args(['--op1', str(op1_val)]).op1)
        self.assertRaises(ParserError, lambda: parser.parse_args(['--op1', '[1, "two"]']))
        self.assertRaises(ParserError, lambda: parser.parse_args(['--op1', '[1.5, 2]']))

        self.assertEqual(op2_val, parser.parse_args(['--op2', str(op2_val)]).op2)
        self.assertEqual(17, parser.parse_args(['--op2', '{"k2": 2}']).op2['k3'])
        self.assertRaises(ParserError, lambda: parser.parse_args(['--op2', '{"k1": 1}']))
        self.assertRaises(ParserError, lambda: parser.parse_args(['--op2', '{"k2": "2"}']))
        self.assertRaises(ParserError, lambda: parser.parse_args(['--op2', '{"k4": 4}']))

        op1_file = os.path.join(self.tmpdir, 'op1.json')
        op2_file = os.path.join(self.tmpdir, 'op2.json')
        cfg1_file = os.path.join(self.tmpdir, 'cfg1.yaml')
        cfg3_file = os.path.join(self.tmpdir, 'cfg3.yaml')
        cfg2_str = 'op1:\n  '+str(op1_val)+'\nop2:\n  '+str(op2_val)+'\n'
        with open(op1_file, 'w') as f:
            f.write(str(op1_val))
        with open(op2_file, 'w') as f:
            f.write(str(op2_val))
        with open(cfg1_file, 'w') as f:
            f.write('op1:\n  '+op1_file+'\nop2:\n  '+op2_file+'\n')
        with open(cfg3_file, 'w') as f:
            f.write('op3:\n  n1:\n  - '+str(op2_val)+'\n')

        cfg = parser.parse_path(cfg1_file)
        self.assertEqual(op1_val, cfg['op1'])
        self.assertEqual(op2_val, cfg['op2'])

        cfg = parser.parse_string(cfg2_str)
        self.assertEqual(op1_val, cfg['op1'])
        self.assertEqual(op2_val, cfg['op2'])

        cfg = parser.parse_args(['--cfg', cfg3_file])
        self.assertEqual(op2_val, cfg.op3['n1'][0])
        parser.check_config(cfg, skip_none=True)

        if os.name == 'posix' and platform.python_implementation() == 'CPython':
            os.chmod(op1_file, 0)
            self.assertRaises(ParserError, lambda: parser.parse_path(cfg1_file))
    def test_ActionParser_nested_dash_names(self):
        p1 = ArgumentParser(error_handler=None)
        p1.add_argument('--op1-like')

        p2 = ArgumentParser(error_handler=None)
        p2.add_argument('--op2-like', action=ActionParser(parser=p1))

        self.assertEqual(
            p2.parse_args(['--op2-like.op1-like=a']).op2_like.op1_like, 'a')

        p3 = ArgumentParser(error_handler=None)
        p3.add_argument('--op3', action=ActionParser(parser=p2))

        self.assertEqual(
            p3.parse_args(['--op3.op2-like.op1-like=b']).op3.op2_like.op1_like,
            'b')
Example #23
0
def parse_args():
    """Parse command-line arguments."""
    parser = ArgumentParser()
    parser.add_argument("data_dir", type=str)
    parser.add_argument("--n_workers", type=int, default=8)
    parser.add_argument("--save_dir", type=str, default=".")
    parser.add_argument("--comment", type=str)

    parser.add_argument("--frames_per_sample", type=int, default=40)
    parser.add_argument("--frames_per_slice", type=int, default=8)
    parser.add_argument("--bits", type=int, default=9)
    parser.add_argument("--conditioning_channels", type=int, default=128)
    parser.add_argument("--embedding_dim", type=int, default=256)
    parser.add_argument("--rnn_channels", type=int, default=896)
    parser.add_argument("--fc_channels", type=int, default=512)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--n_steps", type=int, default=100000)
    parser.add_argument("--valid_every", type=int, default=1000)
    parser.add_argument("--valid_ratio", type=float, default=0.1)
    parser.add_argument("--save_every", type=int, default=10000)
    parser.add_argument("--learning_rate", type=float, default=4e-4)
    parser.add_argument("--decay_every", type=int, default=20000)
    parser.add_argument("--decay_gamma", type=float, default=0.5)
    parser.add_argument("--training_config", action=ActionConfigFile)

    return parser.parse_args()
 def test_list(self):
     for list_type in [Iterable, List, Sequence]:
         with self.subTest(str(list_type)):
             parser = ArgumentParser()
             parser.add_argument('--list', type=list_type[int])
             cfg = parser.parse_args(['--list=[1, 2]'])
             self.assertEqual([1, 2], cfg.list)
    def test_class_path_override_with_default_config_files(self):
        class MyCalendar(Calendar):
            def __init__(self, *args, param: str = '0', **kwargs):
                super().__init__(*args, **kwargs)

        with mock_module(MyCalendar) as module:
            config = {
                'class_path': f'{module}.MyCalendar',
                'init_args': {
                    'firstweekday': 2,
                    'param': '1'
                },
            }
            config_path = os.path.join(self.tmpdir, 'config.yaml')
            with open(config_path, 'w') as f:
                json.dump({'cal': config}, f)

            parser = ArgumentParser(error_handler=None,
                                    default_config_files=[config_path])
            parser.add_argument('--cal', type=Optional[Calendar])

            cfg = parser.instantiate_classes(parser.get_defaults())
            self.assertIsInstance(cfg['cal'], MyCalendar)

            cfg = parser.parse_args([
                '--cal={"class_path": "calendar.Calendar", "init_args": {"firstweekday": 3}}'
            ])
            self.assertEqual(type(parser.instantiate_classes(cfg)['cal']),
                             Calendar)
    def test_subcommand_with_subclass_default_override_lightning_issue_10859(
            self):
        class Arch:
            def __init__(self, a: int = 1):
                pass

        class ArchB(Arch):
            def __init__(self, a: int = 2, b: int = 3):
                pass

        class ArchC(Arch):
            def __init__(self, a: int = 4, c: int = 5):
                pass

        parser = ArgumentParser(error_handler=None)
        parser_subcommands = parser.add_subcommands()
        subparser = ArgumentParser()
        subparser.add_argument('--arch', type=Arch)

        with mock_module(Arch, ArchB, ArchC) as module:
            default = {'class_path': f'{module}.ArchB'}
            value = {
                'class_path': f'{module}.ArchC',
                'init_args': {
                    'a': 10,
                    'c': 11
                }
            }

            subparser.set_defaults(arch=default)
            parser_subcommands.add_subcommand('fit', subparser)

            cfg = parser.parse_args(['fit', f'--arch={json.dumps(value)}'])
            self.assertEqual(cfg.fit.arch.as_dict(), value)
 def _test_typehint_non_parameterized_types(self, type):
     parser = ArgumentParser(error_handler=None)
     ActionTypeHint.is_supported_typehint(type, full=True)
     parser.add_argument('--type', type=type)
     cfg = parser.parse_args(['--type=uuid.UUID'])
     self.assertEqual(cfg.type, uuid.UUID)
     self.assertEqual(parser.dump(cfg), 'type: uuid.UUID\n')
 def test_list_path(self):
     parser = ArgumentParser()
     parser.add_argument('--paths', type=List[Path_fc])
     cfg = parser.parse_args(['--paths=["file1", "file2"]'])
     self.assertEqual(['file1', 'file2'], cfg.paths)
     self.assertIsInstance(cfg.paths[0], Path)
     self.assertIsInstance(cfg.paths[1], Path)
Example #29
0
def parse_args():
    """Parse command-line arguments."""
    parser = ArgumentParser()
    parser.add_argument("data_dir", type=str)
    parser.add_argument("--save_dir", type=str, default=".")
    parser.add_argument("--total_steps", type=int, default=60000)
    parser.add_argument("--warmup_steps", type=int, default=500)
    parser.add_argument("--valid_steps", type=int, default=1000)
    parser.add_argument("--log_steps", type=int, default=100)
    parser.add_argument("--save_steps", type=int, default=5000)
    parser.add_argument("--milestones",
                        type=int,
                        nargs=2,
                        default=[15000, 30000])
    parser.add_argument("--exclusive_rate", type=float, default=1.0)
    parser.add_argument("--n_samples", type=int, default=10)
    parser.add_argument("--accu_steps", type=int, default=2)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--n_workers", type=int, default=8)
    parser.add_argument("--preload", action="store_true")
    parser.add_argument("--comment", type=str)
    parser.add_argument("--ckpt", type=str, default=None)
    parser.add_argument("--grad_norm_clip", type=float, default=10.0)
    parser.add_argument("--use_target_features", action='store_true')
    parser.add_argument("--train_config", action=ActionConfigFile)
    return vars(parser.parse_args())
    def test_typed_Callable_with_function_path(self):
        def my_func_1(p: int) -> str:
            return str(p)

        def my_func_2(p: str) -> int:
            return int(p)

        parser = ArgumentParser(error_handler=None)
        parser.add_argument('--callable', type=Callable[[int], str])

        with mock_module(my_func_1, my_func_2) as module:
            cfg = parser.parse_args([f'--callable={module}.my_func_1'])
            self.assertEqual(my_func_1, cfg.callable)
            cfg = parser.parse_args([f'--callable={module}.my_func_2'])
            self.assertEqual(
                my_func_2,
                cfg.callable)  # Currently callable types are ignored