def test_get_init_arguments_and_types():
    """Check `argparse_utils.get_init_arguments_and_types` against `Trainer`'s signature.

    The helper must report exactly one entry per parameter of
    ``inspect.signature(Trainer)``; entry ``[0]`` is the parameter name and
    entry ``[2]`` must equal that parameter's declared default. The collected
    defaults alone must then be enough to construct a ``Trainer``.
    """
    reported = argparse_utils.get_init_arguments_and_types(Trainer)
    signature_params = inspect.signature(Trainer).parameters

    # one reported entry per constructor parameter ...
    assert len(signature_params) == len(reported)
    # ... and every entry carries that parameter's declared default
    for entry in reported:
        name, default = entry[0], entry[2]
        assert signature_params[name].default == default

    # Building a Trainer purely from the reported defaults must succeed.
    default_kwargs = {entry[0]: entry[2] for entry in reported}
    trainer = Trainer(**default_kwargs)
    assert isinstance(trainer, Trainer)

    # NOTE(review): `overwrite_by_env_vars` below is defined but never called in
    # this test, and it refers to `fn`, `get_init_arguments_and_types` and
    # `parse_env_variables`, none of which are defined in the code visible here.
    # It looks like a decorator body pasted into this test -- confirm and remove.
    def overwrite_by_env_vars(self, *args, **kwargs):
        """Fold positional args into kwargs, apply env-variable overrides, call ``fn``."""
        cls = self.__class__
        if args:
            # Pair each positional value with the matching init-argument name;
            # only the name (element 0) of each reported entry is needed.
            positional_names = [
                spec[0] for spec in get_init_arguments_and_types(cls)
            ]
            kwargs.update(dict(zip(positional_names, args)))
        # Env-variable values are applied last, so they win over explicit args.
        # todo: maybe add a warning that some init args were overwritten by Env arguments
        kwargs.update(vars(parse_env_variables(cls)))

        # Everything is keyword-based now, so forward only kwargs.
        return fn(self, **kwargs)
Example #3
0
    def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
        r"""Return a parser that extends ``parent_parser`` with `Trainer` options.

        A fresh ``ArgumentParser`` is built with ``parent_parser`` as its parent
        (and without a help option of its own). One ``--<name>`` option is then
        added for every argument reported by
        ``argparse_utils.get_init_arguments_and_types(cls)``, except ``kwargs``
        and the names returned by ``cls.get_deprecated_arg_names()``.

        Args:
            parent_parser:
                The custom cli arguments parser, which will be extended by
                the Trainer default arguments.

        Returns:
            The newly created parser, which has ``parent_parser`` as a parent.

        Only arguments whose declared types include at least one of
        (str, int, float, bool) get an option; all others are skipped.
        Arguments that accept ``bool`` may also be passed as a bare
        ``--<name>``, which sets them to ``True``.

        Examples:
            >>> import argparse
            >>> import pprint
            >>> parser = argparse.ArgumentParser()
            >>> parser = Trainer.add_argparse_args(parser)
            >>> args = parser.parse_args([])
            >>> pprint.pprint(vars(args))  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
            {...
             'check_val_every_n_epoch': 1,
             'checkpoint_callback': True,
             'default_root_dir': None,
             'deterministic': False,
             'distributed_backend': None,
             'early_stop_callback': False,
             ...
             'logger': True,
             'max_epochs': 1000,
             'max_steps': None,
             'min_epochs': 1,
             'min_steps': None,
             ...
             'profiler': None,
             'progress_bar_refresh_rate': 1,
             ...}

        """
        parser = ArgumentParser(
            parents=[parent_parser],
            add_help=False,
        )

        # Names that never become CLI options: deprecated init arguments plus
        # the catch-all ``kwargs``.
        skipped_names = cls.get_deprecated_arg_names() + ['kwargs']
        supported_types = (str, int, float, bool)

        # TODO: get "help" from docstring :)
        for spec in argparse_utils.get_init_arguments_and_types(cls):
            if spec[0] in skipped_names:
                continue
            arg, arg_types, arg_default = spec

            # Keep only the supported types, in the order of `supported_types`
            # (so the list never contains duplicates).
            cli_types = [t for t in supported_types if t in arg_types]
            if not cli_types:
                # argparse has no converter for any of the declared types
                continue

            extra_kwargs = {}
            if bool not in cli_types:
                use_type = cli_types[0]
            else:
                # A bare ``--<name>`` (no value) yields True.
                extra_kwargs = {'nargs': '?', 'const': True}
                if cli_types == [bool]:
                    use_type = parsing.str_to_bool
                elif set(cli_types) == {str, bool}:
                    use_type = parsing.str_to_bool_or_str
                else:
                    # mixed with other types: use the first, more general,
                    # non-bool type
                    use_type = next(t for t in cli_types if t is not bool)

            if arg in ('gpus', 'tpu_cores'):
                use_type = Trainer._gpus_allowed_type
                arg_default = Trainer._gpus_arg_default

            # hack for types in (int, float)
            if set(cli_types) == {int, float}:
                use_type = Trainer._int_or_float_type

            # hack for track_grad_norm
            if arg == 'track_grad_norm':
                use_type = float

            parser.add_argument(
                f'--{arg}',
                dest=arg,
                default=arg_default,
                type=use_type,
                help='autogenerated by pl.Trainer',
                **extra_kwargs,
            )

        return parser