Пример #1
0
def bind(
    parser: argparse.ArgumentParser,
    hp_mgr: hpman.HyperParameterManager,
    *,
    inject_actions: Union[bool, List[str]] = True,
    action_prefix: str = config.HP_ACTION_PREFIX_DEFAULT,
    serial_format: str = config.HP_SERIAL_FORMAT_DEFAULT,
    show_defaults: bool = True,
):
    """Bridge the gap between argparse and hpman.

    This is the most important function. Once the parser is bound,
    hpargparse will do the rest for you.

    Command-line options are injected into ``parser`` via ``inject_args``.
    ``parser.parse_known_args`` is then replaced (on this parser instance
    only) by a wrapper that applies the parsed values to ``hp_mgr``.
    ``argparse.ArgumentParser.parse_args`` delegates to
    ``parse_known_args``, so ``parse_args`` goes through the wrapper too.

    After the original parsing, the wrapper does the following, in order:

    1. ``load``: load saved hyperparameters from the given file.
    2. Apply the hyperparameter values set on the command line.
    3. ``save``: save the current hyperparameters to the given file.
    4. ``list``: print the values as yaml, json, or a detailed listing,
       then ``sys.exit(0)``.
    5. ``detail``: print the detailed listing, then ``sys.exit(0)``.
    6. ``exit``: ``sys.exit(0)`` if that option is set.

    Each action only takes effect when it was injected. Actions that were
    not injected are treated as unset instead of raising.

    :param parser: A :class:`argparse.ArgumentParser` object. It is
        modified in place.
    :param hp_mgr: The hyperparameter manager from `hpman`. It is
        usually an 'underscore' variable obtained by `from hpman.m import _`
    :param inject_actions: A list of actions names to inject, or True, to
        inject all available actions. Available actions are 'save', 'load',
        'detail' and 'list'
    :param action_prefix: Prefix for options of hpargparse injected additional
        actions. e.g., the default action_prefix is 'hp'. Therefore, the
        command line options added by :func:`.bind` will be '--hp-save',
        '--hp-load', '--hp-list', etc.
    :param serial_format: One of 'auto', 'yaml' and 'pickle'. Defaults to
        'auto'. In most cases you need not alter this argument as long as
        you give the right file extension when using save and load action. To
        be specific, '.yaml' and '.yml' would be deemed as yaml format, and
        '.pickle' and '.pkl' would be seen as pickle format.
    :param show_defaults: Show the default value in help messages.

    :note: pickle is done by `dill` to support pickling of more types.
    """

    # make action list to be injected
    inject_actions = parse_action_list(inject_actions)

    inject_args(
        parser,
        hp_mgr,
        inject_actions=inject_actions,
        action_prefix=action_prefix,
        serial_format=serial_format,
        show_defaults=show_defaults,
    )

    # Capture the pre-hook method in this closure instead of looking it up
    # through ``self`` at call time. If ``_original_parse_known_args`` were
    # later overwritten (e.g. by binding the same parser again), a call-time
    # lookup would resolve to the hooked method itself and recurse forever.
    original_parse_known_args = parser.parse_known_args
    # Still exposed as an attribute for code that may read it.
    parser._original_parse_known_args = original_parse_known_args

    def get_action_value(namespace, name):
        # An action that was not injected may have no attribute on the
        # namespace. Report it as unset (None) instead of raising
        # AttributeError, so the "<action> in inject_actions" guards below
        # decide what runs.
        return getattr(namespace, "{}_{}".format(action_prefix, name), None)

    def new_parse_known_args(self, *args, **kwargs):
        namespace, extras = original_parse_known_args(*args, **kwargs)

        # load saved hyperparameter instance
        load_value = get_action_value(namespace, "load")
        if "load" in inject_actions and load_value is not None:
            hp_load(load_value, hp_mgr, serial_format)

        # Set hyperparameters given on the command line. This runs after
        # "load", so explicit command-line values override loaded ones.
        # The helper below is presumably attached to the parser by
        # inject_args (TODO confirm). This code is not inside a class body,
        # so the double-underscore name is not mangled.
        for name in self.__hpargparse_value_names_been_set():
            assert hasattr(namespace, name), name
            value = getattr(namespace, name)
            if isinstance(value, StringAsDefault):
                value = str(value)
            hp_mgr.set_value(name, value)

        save_value = get_action_value(namespace, "save")
        if "save" in inject_actions and save_value is not None:
            hp_save(save_value, hp_mgr, serial_format)

        hp_list_value = get_action_value(namespace, "list")
        if "list" in inject_actions and hp_list_value is not None:
            if hp_list_value in ("yaml", "json"):
                values = hp_mgr.get_values()
                if hp_list_value == "yaml":
                    text = yaml.dump(values).replace("\n\n", "\n")
                else:
                    text = json.dumps(values)
                # The format name doubles as the syntax-highlighting lexer.
                Console().print(Syntax(text, hp_list_value, theme="monokai"))
            else:
                assert hp_list_value == "detail", hp_list_value
                hp_list(hp_mgr)

            sys.exit(0)

        if "detail" in inject_actions and get_action_value(namespace, "detail"):
            hp_list(hp_mgr)
            sys.exit(0)

        if inject_actions and get_action_value(namespace, "exit"):
            sys.exit(0)

        return namespace, extras

    parser.parse_known_args = MethodType(new_parse_known_args, parser)