def test_subcommand_with_subclass_default_override_lightning_issue_10859(self):
        """Override a subclass default of a subcommand argument from argv.

        The ``fit`` subparser declares ``--arch`` with type ``Arch`` and a
        default pointing at ``ArchB``. The command line then supplies an
        ``ArchC`` specification, and the parsed ``cfg.fit.arch`` must be
        equal to exactly that specification.
        """
        # The class names are significant: they are referenced by name in
        # the ``class_path`` strings built below.
        class Arch:
            def __init__(self, a: int = 1):
                pass

        class ArchB(Arch):
            def __init__(self, a: int = 2, b: int = 3):
                pass

        class ArchC(Arch):
            def __init__(self, a: int = 4, c: int = 5):
                pass

        root_parser = ArgumentParser(error_handler=None)
        subcommands = root_parser.add_subcommands()
        fit_parser = ArgumentParser()
        fit_parser.add_argument('--arch', type=Arch)

        with mock_module(Arch, ArchB, ArchC) as module:
            # Default selects ArchB without any init_args.
            fit_parser.set_defaults(arch={'class_path': f'{module}.ArchB'})
            subcommands.add_subcommand('fit', fit_parser)

            # The command line switches to ArchC with explicit init_args.
            expected = {
                'class_path': f'{module}.ArchC',
                'init_args': {'a': 10, 'c': 11},
            }
            argv = ['fit', '--arch=' + json.dumps(expected)]

            parsed = root_parser.parse_args(argv)
            self.assertEqual(parsed.fit.arch.as_dict(), expected)
예제 #2
0
def get_parser():
    """Build the command line parser with its three subcommands.

    Returns:
        The global ``ArgumentParser``. It has a ``--stack_trace`` option and
        the subcommands ``validate``, ``render`` and ``schema``. Each
        subcommand parser is also attached to the returned parser as the
        attribute ``parser_<subcommand name>``.
    """
    # --- validate subcommand -------------------------------------------
    validate_parser = ModuleArchitecture.get_config_parser()
    validate_parser.description = (
        'Command for checking the validity of neural network module architecture files.'
    )
    validate_parser.set_defaults(propagators='default')
    validate_help = 'Path(s) to neural network module architecture file(s) in jsonnet narchi format.'
    validate_parser.add_argument(
        'jsonnet_paths',
        nargs='+',
        action=ActionPath(mode='fr'),
        help=validate_help,
    )

    # --- render subcommand ---------------------------------------------
    render_parser = ModuleArchitectureRenderer.get_config_parser()
    render_parser.description = (
        'Command for rendering a neural network module architecture file.'
    )
    render_parser.set_defaults(propagators='default')
    render_in_help = 'Path to a neural network module architecture file in jsonnet narchi format.'
    render_parser.add_argument(
        'jsonnet_path',
        action=ActionPath(mode='fr'),
        help=render_in_help,
    )
    render_out_help = (
        'Path where to write the architecture diagram (with a valid extension for pygraphviz draw). '
        'If unset a pdf is saved to the output directory.'
    )
    render_parser.add_argument(
        'out_file',
        action=ActionPath(mode='fc'),
        nargs='?',
        help=render_out_help,
    )

    # --- schema subcommand ---------------------------------------------
    schema_parser = ArgumentParser(description='Prints a schema as a pretty json.')
    # The ``None`` key of ``schemas`` is not offered as a choice.
    schema_names = [name for name in schemas if name is not None]
    schema_parser.add_argument(
        'schema',
        nargs='?',
        choices=schema_names,
        default='narchi',
        help='Which of the available schemas to print.',
    )

    # --- global parser -------------------------------------------------
    parser = ArgumentParser(version=__version__, description=__doc__)
    parser.add_argument(
        '--stack_trace',
        default=False,
        type=bool,
        help='Whether to print stack trace when there are errors.',
    )

    named_subparsers = (
        ('validate', validate_parser),
        ('render', render_parser),
        ('schema', schema_parser),
    )

    # Expose every subcommand parser as ``parser.parser_<name>``.
    for name, subparser in named_subparsers:
        setattr(parser, 'parser_' + name, subparser)

    subcommands = parser.add_subcommands()
    for name, subparser in named_subparsers:
        subcommands.add_subcommand(name, subparser)

    return parser
예제 #3
0
    def test_parser_mode_subparsers(self):
        """Setting ``parser_mode`` on a parser is reflected on its subparser.

        A loader named ``custom`` is registered inside a patched copy of the
        loaders dict, ``parser_mode`` is assigned on the parent parser only,
        and both the parent and the ``sub`` subparser must then report
        ``'custom'``.
        """
        parser = ArgumentParser()
        subparser = ArgumentParser()
        parser.add_subcommands().add_subcommand('sub', subparser)

        # patch.dict restores the loaders dict when the block exits.
        with unittest.mock.patch.dict('jsonargparse.loaders_dumpers.loaders'):
            set_loader('custom', yaml.safe_load)
            parser.parser_mode = 'custom'
            for checked in (parser, subparser):
                self.assertEqual('custom', checked.parser_mode)
예제 #4
0
    def test_parser_mode_omegaconf_in_subcommands(self):
        """Check ``${source}`` interpolation in a subcommand config.

        The parent parser is created with ``parser_mode='omegaconf'``. A
        config passed to the ``sub`` subcommand sets ``target`` to
        ``'${source}'``, and the parsed ``target`` must equal the value of
        ``source``, i.e. ``'hello'``.
        """
        parser = ArgumentParser(error_handler=None, parser_mode='omegaconf')

        subparser = ArgumentParser()
        subparser.add_argument('--config', action=ActionConfigFile)
        for option in ('--source', '--target'):
            subparser.add_argument(option, type=str)

        parser.add_subcommands().add_subcommand('sub', subparser)

        # ``target`` refers to ``source`` through an interpolation string.
        config_text = yaml_dump({'source': 'hello', 'target': '${source}'})
        cfg = parser.parse_args(['sub', '--config=' + config_text])

        self.assertEqual(cfg.sub.target, 'hello')