def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
    r"""Scan the DataModule signature and return argument names, types and default values.

    This is a thin wrapper: it forwards ``cls`` unchanged to
    ``get_init_arguments_and_types`` and returns that call's result.

    NOTE(review): the body calls the bare name ``get_init_arguments_and_types``.
    Presumably this def lives inside a class body (as a classmethod) and the
    name resolves to a module-level helper of the same name; if it were a
    module-level function it would call itself without end. Confirm against
    the enclosing class and the file's imports.

    Returns:
        List with tuples of 3 values:
        (argument name, set with argument types, argument default value).
    """
    return get_init_arguments_and_types(cls)
def test_get_init_arguments_and_types():
    """Check that ``argparse.get_init_arguments_and_types`` agrees with ``Trainer``'s signature.

    The helper must report exactly as many entries as ``inspect.signature``
    finds parameters, each entry's default must equal the signature's default
    for that name, and passing all reported defaults back as keyword arguments
    must build a ``Trainer``.
    """
    reported = argparse.get_init_arguments_and_types(Trainer)
    signature_params = inspect.signature(Trainer).parameters

    # One reported entry per parameter of the real signature.
    assert len(signature_params) == len(reported)

    # Entries are indexed as (name, types, default); only name and default
    # are checked here. The defaults are collected for the constructor call.
    defaults_by_name = {}
    for entry in reported:
        name, default = entry[0], entry[2]
        assert signature_params[name].default == default
        defaults_by_name[name] = default

    trainer = Trainer(**defaults_by_name)
    assert isinstance(trainer, Trainer)
def insert_env_defaults(self, *args, **kwargs):
    """Call ``fn`` by keyword, with environment-derived values used as defaults.

    Positional arguments are paired, in order, with the argument names that
    ``get_init_arguments_and_types`` reports for the instance's class and are
    folded into ``kwargs``. The attributes of ``parse_env_variables(cls)`` are
    then merged in underneath, so a value passed by the caller wins over the
    environment value for the same name.

    NOTE(review): ``fn`` is a free variable in this block -- presumably the
    callable wrapped by an enclosing decorator; confirm in the outer scope.

    Returns:
        Whatever ``fn(self, **kwargs)`` returns for the merged keyword arguments.
    """
    owner = self.__class__

    if args:
        # Only the names (first element of each entry) are needed to turn
        # positional arguments into keyword arguments.
        param_names = [spec[0] for spec in get_init_arguments_and_types(owner)]
        kwargs.update(zip(param_names, args))

    # Environment values are unpacked first so that caller-supplied kwargs
    # override them on any key clash.
    env_defaults = vars(parse_env_variables(owner))
    merged = {**env_defaults, **kwargs}

    # Every argument is passed by keyword from here on.
    return fn(self, **merged)
def overwrite_by_env_vars(self, *args, **kwargs):
    """Call ``fn`` by keyword, letting environment-derived values override the caller's.

    Positional arguments are paired, in order, with the argument names that
    ``get_init_arguments_and_types`` reports for the instance's class and are
    folded into ``kwargs``. The attributes of ``parse_env_variables(cls)`` are
    then written on top, replacing any caller-supplied value with the same name.

    NOTE(review): ``fn`` is a free variable in this block -- presumably the
    callable wrapped by an enclosing decorator; confirm in the outer scope.

    Returns:
        Whatever ``fn(self, **kwargs)`` returns for the final keyword arguments.
    """
    owner = self.__class__

    if args:
        # Only the names (first element of each entry) are needed to turn
        # positional arguments into keyword arguments.
        param_names = [spec[0] for spec in get_init_arguments_and_types(owner)]
        kwargs.update(zip(param_names, args))

    # Environment values are applied last, so they take precedence.
    # todo: maybe add a warning that some init args were overwritten by Env arguments
    env_overrides = vars(parse_env_variables(owner))
    kwargs.update(env_overrides)

    # Every argument is passed by keyword from here on.
    return fn(self, **kwargs)