示例#1
0
def check_nonempty(cfg: Config, fields: Sequence[Union[str, Sequence[str]]]):
    """Validate that required configuration fields are set.

    Values are looked up with ``get_by_dotkey``; a falsy value counts as unset.

    Args:
        cfg: Configuration to inspect.
        fields: Each entry is either a dot-key string, which must be set in
            ``cfg``, or a sequence of dot-keys of which at least one must be set.

    Raises:
        ValueError: If any requirement is unmet. The message lists every
            unmet requirement, one per line.
    """
    errors = []
    for key in fields:
        if isinstance(key, str):
            if not get_by_dotkey(cfg, key):
                errors.append(f"{key} is required.")
        # Any non-string entry is a group of alternatives. Accept every
        # sequence type (e.g. tuple), not only list, so a group is never
        # silently skipped.
        elif not any(get_by_dotkey(cfg, k) for k in key):
            errors.append(f"Any of {', '.join(key)} is required.")
    if errors:
        raise ValueError("\n".join(errors))
示例#2
0
def resolve_path(cfg: Config) -> Config:
    """Return ``cfg`` with each configured path field normalized.

    For every dot-key in ``PATH_FIELDS`` whose current value is truthy, the
    value is passed through ``convert_fullpath_if_path`` and merged back into
    the config with ``OmegaConf.merge``. Falsy or missing fields are left
    untouched. Keys are processed in order, and each lookup sees the result of
    the previous merges.
    """
    resolved = cfg
    for dotkey in PATH_FIELDS:
        raw = get_by_dotkey(resolved, dotkey)
        if not raw:
            # Field not set; nothing to convert.
            continue
        override = create_dict_from_dotkey(dotkey, convert_fullpath_if_path(raw))
        resolved = OmegaConf.merge(resolved, OmegaConf.create(override))
    return resolved
示例#3
0
def load_scheduler(
    cfg: Config, optimizer: Optimizer
) -> Union[torch.optim.lr_scheduler.LambdaLR, Type[DummyScheduler]]:
    """Build the learning-rate scheduler described by the ``scheduler`` section.

    ``scheduler.class`` is an import path resolved with ``import_attr``.
    ``scheduler.params``, which is optional, supplies keyword arguments for the
    constructor, called as ``cls(optimizer, **params)``.

    Args:
        cfg: Configuration holding the ``scheduler`` section.
        optimizer: Optimizer passed as the first positional argument to the
            scheduler constructor.

    Returns:
        The constructed scheduler instance. If no scheduler class is
        configured, returns the ``DummyScheduler`` class itself (not an
        instance).
    """
    cls_str = get_by_dotkey(cfg, "scheduler.class")
    if not cls_str:
        return DummyScheduler
    cls = cast(Type[torch.optim.lr_scheduler.LambdaLR], import_attr(cls_str))
    # Look up params the same way as ``scheduler.class`` so that an absent or
    # null ``params`` means "no extra kwargs" instead of raising.
    # NOTE(review): assumes get_by_dotkey returns a falsy value for missing
    # keys, as the ``scheduler.class`` check above relies on; confirm.
    raw_params = get_by_dotkey(cfg, "scheduler.params")
    if OmegaConf.is_config(raw_params):
        # resolve=True so ``${...}`` interpolations reach the constructor as
        # their resolved values rather than as literal strings.
        params = OmegaConf.to_container(raw_params, resolve=True)
    else:
        params = raw_params
    return cls(optimizer, **(params or {}))