Example no. 1
from typing import Any, Tuple

import torch
from omegaconf import OmegaConf
from torch.optim import Optimizer


def build_optimizer_and_scheduler(
        conf: OmegaConf, model: torch.nn.Module) -> Tuple[Optimizer, Any]:
    # Resolve interpolations and convert the config to a plain dict so keys
    # can be popped and the remainder forwarded as optimizer kwargs.
    conf = OmegaConf.to_container(conf, resolve=True)

    # `import_` is a project-specific helper (presumably resolving the
    # configured name, e.g. a dotted path, to the optimizer class).
    optimizer_fn = import_(conf.pop("name"))
    lr = conf.pop("lr", 1e-3)  # default learning rate if none is configured

    optimizer = optimizer_fn(
        [
            # The backbone is fine-tuned with a 10x smaller learning rate
            # than the classifier head.
            {"params": model.backbone.parameters(), "lr": lr * 0.1},
            {"params": model.classifier.parameters(), "lr": lr},
        ],
        **conf,
    )

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode="min",
                                                           patience=2)

    return optimizer, scheduler
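
A minimal invocation sketch for this builder. The config keys, the AdamW path, and the `backbone`/`classifier` attributes on the toy model are illustrative assumptions; `import_` is expected to resolve the configured name to an optimizer class.

from omegaconf import OmegaConf
import torch

# Toy model exposing the `backbone` / `classifier` submodules the builder expects.
model = torch.nn.Module()
model.backbone = torch.nn.Linear(16, 8)
model.classifier = torch.nn.Linear(8, 2)

conf = OmegaConf.create({
    "name": "torch.optim.AdamW",  # resolved to a class by `import_` (assumption)
    "lr": 3e-4,
    "weight_decay": 0.01,         # any remaining keys become optimizer kwargs
})
optimizer, scheduler = build_optimizer_and_scheduler(conf, model)
# After each validation epoch: scheduler.step(val_loss)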
Example no. 2
from omegaconf import OmegaConf, open_dict


def process_config(cfg: OmegaConf):
    """Validate the config and extract `name`, `checkpoint_paths` and `save_ckpt_only`."""
    if 'name' not in cfg or cfg.name is None:
        raise ValueError(
            "`cfg.name` must be provided to save a model checkpoint")

    if 'checkpoint_paths' not in cfg or cfg.checkpoint_paths is None:
        raise ValueError(
            "`cfg.checkpoint_paths` must be provided as a list of one or more str paths to "
            "pytorch lightning checkpoints")

    save_ckpt_only = False

    # `open_dict` temporarily disables struct mode so keys can be popped.
    with open_dict(cfg):
        name_prefix = cfg.name
        checkpoint_paths = cfg.pop('checkpoint_paths')

        if 'save_ckpt_only' in cfg:
            save_ckpt_only = cfg.pop('save_ckpt_only')

    # Accept a single comma-separated string (possibly wrapped in brackets)
    # and normalize it into a list of stripped path strings.
    if not isinstance(checkpoint_paths, (list, tuple)):
        checkpoint_paths = str(checkpoint_paths).replace("[", "").replace("]", "")
        checkpoint_paths = [
            ckpt_path.strip() for ckpt_path in checkpoint_paths.split(",")
        ]

    return name_prefix, checkpoint_paths, save_ckpt_only
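
A small sketch of how the helper normalizes its input, assuming the paths arrive as a single bracketed, comma-separated string (for example from a CLI override); the key values and file names are placeholders.

from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "name": "averaged_model",
    "checkpoint_paths": "[a.ckpt, b.ckpt, c.ckpt]",
})
name_prefix, checkpoint_paths, save_ckpt_only = process_config(cfg)
# checkpoint_paths == ["a.ckpt", "b.ckpt", "c.ckpt"], save_ckpt_only == False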
Example no. 3
import logging

from omegaconf import DictConfig, OmegaConf
from omegaconf import errors as omegaconf_errors


def _convert_config(cfg: OmegaConf):
    """Recursive function converting the configuration from the old hydra format to the new one."""

    # Rename the legacy `cls` key to `_target_` (unless `_target_` is already set).
    if 'cls' in cfg and "_target_" not in cfg:
        cfg._target_ = cfg.pop("cls")

    # Flatten the legacy `params` block into the top level of the node.
    if 'params' in cfg:
        params = cfg.pop('params')
        for param_key, param_val in params.items():
            cfg[param_key] = param_val

    # Recurse into nested DictConfig nodes.
    try:
        for _, sub_cfg in cfg.items():
            if isinstance(sub_cfg, DictConfig):
                _convert_config(sub_cfg)
    except omegaconf_errors.OmegaConfBaseException as e:
        logging.warning(
            f"Skipping config conversion for cfg:\n{cfg}\n"
            f"due to an OmegaConf error:\n{e}.")
Example no. 4
from typing import Any, Tuple

import torch
from omegaconf import OmegaConf
from torch.optim import Optimizer


def build_optimizer_and_scheduler(
        conf: OmegaConf, model: torch.nn.Module) -> Tuple[Optimizer, Any]:
    # Resolve interpolations and convert to a plain dict; all keys except
    # `name` are forwarded to the optimizer constructor.
    conf = OmegaConf.to_container(conf, resolve=True)

    # `import_` is a project-specific helper resolving the configured name to a class.
    optimizer_fn = import_(conf.pop("name"))
    optimizer = optimizer_fn(model.parameters(), **conf)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode="min",
                                                           patience=2)

    return optimizer, scheduler
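
As with the first example, a minimal invocation sketch; the SGD settings are illustrative and `import_` is assumed to resolve the configured name.

from omegaconf import OmegaConf
import torch

model = torch.nn.Linear(16, 2)
conf = OmegaConf.create({
    "name": "torch.optim.SGD",  # resolved by `import_` (assumption)
    "lr": 0.1,
    "momentum": 0.9,            # forwarded via **conf
})
optimizer, scheduler = build_optimizer_and_scheduler(conf, model)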