Example #1
0
def inject_args(
    parser: argparse.ArgumentParser,
    hp_mgr: hpman.HyperParameterManager,
    *,
    inject_actions: List[str],
    action_prefix: str,
    serial_format: str,
    show_defaults: bool,
) -> argparse.ArgumentParser:
    """Inject hpman parsed hyperparameter settings into argparse arguments.

    Every hyperparameter returned by ``hp_mgr.get_values()`` becomes a
    ``--<name>`` option (underscores in the name are turned into dashes)
    whose default is the hyperparameter's current value. Only a limited set
    of value types is supported; the conversion callable comes from
    ``_get_argument_type_by_value`` (``str2bool`` is used for bools).

    The hpargparse control options named in ``inject_actions`` (``list``,
    ``detail``, ``save``, ``load``) are added as
    ``--<action_prefix>-<action>``; unrecognized action names are ignored.
    If ``load`` or ``save`` is requested, a ``serial-format`` option is added
    too, and if any action is requested, an ``exit`` option is added.

    The parser is modified in place: its ``formatter_class`` may be replaced
    and it gains a ``__hpargparse_value_names_been_set()`` method returning
    the set of hyperparameter names explicitly given on the command line in
    any parse performed so far (the set is never cleared).

    :param parser: Use given parser object of :class:`argparse.ArgumentParser`.
    :param hp_mgr: A :class:`hpman.HyperParameterManager` object.

    :param inject_actions: A list of action names to inject.
    :param action_prefix: Prefix for hpargparse related options.
    :param serial_format: Default for the ``serial-format`` option; it is
        expected to be one of ``config.HP_SERIAL_FORMAT_CHOICES``.
    :param show_defaults: Show default values in ``--help`` output.

    :return: The injected parser (the same object as ``parser``).
    """

    help_text = ""
    if show_defaults:
        parser.formatter_class = argparse.ArgumentDefaultsHelpFormatter

        # Default value will be shown when using argparse.ArgumentDefaultsHelpFormatter
        # only if a help message is present. This is the behavior of argparse.
        help_text = " "

    value_names_been_set = set()

    def _make_recording_action(name):
        """Return a store-style Action class that records ``name`` when used.

        Recording happens in the action, not in the ``type`` callable:
        argparse also runs ``type`` on *string defaults* of options that were
        not given, which would wrongly mark string-valued hyperparameters as
        explicitly set. Actions are only invoked for options actually given.
        """

        class _RecordingStoreAction(argparse.Action):
            def __call__(self, _parser, namespace, values, option_string=None):
                value_names_been_set.add(name)
                setattr(namespace, self.dest, values)

        return _RecordingStoreAction

    # add options for collected hyper-parameters
    for name, value in hp_mgr.get_values().items():
        # this is just a simple hack
        option_name = "--{}".format(name.replace("_", "-"))

        arg_type = _get_argument_type_by_value(value)
        if arg_type == bool:
            # argparse does not directly support bool types.
            arg_type = str2bool

        parser.add_argument(
            option_name,
            action=_make_recording_action(name),
            type=arg_type,
            default=value,
            help=help_text,
        )

    def make_option(option):
        """Build an hpargparse control option name, e.g. ``--hp-list``."""
        return "--{}-{}".format(action_prefix, option)

    for action in inject_actions:
        if action == "list":
            parser.add_argument(
                make_option("list"),
                action="store",
                default=None,
                const="yaml",
                nargs="?",
                choices=["detail", "yaml", "json"],
                help=(
                    "List all available hyperparameters. If `{} detail` is"
                    " specified, a verbose table will be printed"
                ).format(make_option("list")),
            )
        elif action == "detail":
            parser.add_argument(
                make_option("detail"),
                action="store_true",
                # Name the real list option so the text matches action_prefix.
                help="Shorthand for {} detail".format(make_option("list")),
            )
        elif action == "save":
            parser.add_argument(
                make_option("save"),
                help=(
                    "Save hyperparameters to a file. The hyperparameters"
                    " are saved after processing of all other options"
                ),
            )

        elif action == "load":
            parser.add_argument(
                make_option("load"),
                help=(
                    "Load hyperparameters from a file. The hyperparameters"
                    " are loaded before any other options are processed"
                ),
            )

    if "load" in inject_actions or "save" in inject_actions:
        parser.add_argument(
            make_option("serial-format"),
            default=serial_format,
            choices=config.HP_SERIAL_FORMAT_CHOICES,
            help=(
                "Format of the saved config file. Defaults to {}."
                " It can be set to override auto file type deduction."
            ).format(serial_format),
        )

    if inject_actions:
        parser.add_argument(
            make_option("exit"),
            action="store_true",
            help="process all hpargparse actions and quit",
        )

    def __hpargparse_value_names_been_set(self):
        # Names of hyperparameters explicitly passed on the command line.
        return value_names_been_set

    parser.__hpargparse_value_names_been_set = MethodType(
        __hpargparse_value_names_been_set, parser
    )

    return parser