Exemplo n.º 1
0
def _get_cfg(
    base_cfg_path: Path,
    template_cfg_path: Optional[Path],
    run_cfg_path: Optional[Path],
) -> CFG_T:
    """Load config either from template file or from run file.

    If both template_cfg_path and run_cfg_path are None, an empty
    config is returned.

    Args:
        base_cfg_path: The path to the directory with all the config files.
        template_cfg_path: The path to the template config file for this run.
        run_cfg_path: The path to the config file created by the user for this run.

    Returns:
        The loaded config.
    """
    if run_cfg_path is not None:
        return ConfigHandler.load_cfg(run_cfg_path, base_cfg_path)
    elif template_cfg_path is not None:
        template_cfg = ConfigHandler.load_cfg(template_cfg_path, base_cfg_path)
        tui = TUI(template_cfg, base_cfg_path)
        return tui.show()
    else:
        return {}
Exemplo n.º 2
0
def _create_out_dir_and_save_run_file(
    cfg: CFG_T,
    out_dir_root: str,
    run_cfg_path: Optional[Path],
) -> None:
    """Create the output directory for the current run and save the run file.

    The run directory is ``out_dir_root / <run name>``. Unless the given run
    file already is the run file of that directory, the directory is created,
    its absolute path is stored in ``cfg`` under ``OUT_DIR_KEY`` (``cfg`` is
    modified in place), and ``cfg`` is saved as the run file inside it.

    Args:
        cfg: The loaded config.
        out_dir_root: The root for output directories.
        run_cfg_path: The path to the config file created by the user for this run.

    Raises:
        ValueError: If the run name is missing or empty in the given config.
        FileExistsError: If the run directory has to be created but already
            exists.
    """
    run_name = cfg.get(RUN_NAME_KEY, "")
    # Reject every falsy value, not only "": a key that is present but set to
    # None would otherwise slip through and crash in Path() with a TypeError.
    if not run_name:
        msg = f"The config must contain a valid name for the run (key={RUN_NAME_KEY})."
        raise ValueError(msg)

    run_dir = Path(out_dir_root) / Path(run_name)
    run_file = run_dir / RUN_FILE_NAME

    # The config was loaded from the run file of this very directory:
    # the directory exists and the file is already in place.
    if run_cfg_path is not None and run_file.absolute() == run_cfg_path.absolute():
        return

    # exist_ok=False on purpose: an existing directory must not be reused.
    run_dir.mkdir(parents=True, exist_ok=False)
    cfg[OUT_DIR_KEY] = str(run_dir.absolute())
    ConfigHandler.save_cfg(cfg, run_file)
Exemplo n.º 3
0
def test_replace_bases() -> None:
    """Check that a chain of bases is resolved into a single flat config.

    The config points to ``bases.a``, which points to ``bases.b``, which in
    turn points to ``bases.c``. The result is expected to contain exactly the
    parameters ``p1`` to ``p4``, each with the value defined at its level.
    """
    chained_bases = {
        "bases": {
            "a": {"base": "bases.b", "p2": 1.23},
            "b": {"base": "bases.c", "p3": True},
            "c": {"p4": (1, 2, 3)},
        }
    }
    start_cfg = {"base": "bases.a", "p1": 1}

    resolved = ConfigHandler.replace_bases(start_cfg, chained_bases)

    # Only the parameters survive; no other key is expected in the result.
    assert {"p1", "p2", "p3", "p4"} == set(resolved.keys())

    assert resolved["p1"] == 1
    assert resolved["p2"] == 1.23
    assert resolved["p3"] is True
    assert resolved["p4"] == (1, 2, 3)
Exemplo n.º 4
0
    def before_exit(self) -> None:
        """Collect the values edited by the user and store the run config.

        Every widget that has a handler contributes to a fresh config. The
        bases of that config are then replaced in two passes (first using
        ``BaseWidgetHandler.BASE_KEY`` as base key, then using the default
        base key of ``ConfigHandler.replace_bases``), and the result is
        merged into the run config of the parent app.
        """
        collected: CFG_T = {}
        for handler, widget in self.widgets:
            # Widgets without a handler do not contribute to the config.
            if handler is None:
                continue
            collected = handler.update_cfg(collected, widget)

        available_bases = ConfigHandler.load_base_cfgs(self.parent.base_cfg_dir)

        # First pass: bases referenced through the widget-specific key.
        widget_bases_resolved = ConfigHandler.replace_bases(
            collected,
            available_bases,
            base_key=BaseWidgetHandler.BASE_KEY,
        )
        # Second pass: bases referenced through the default key.
        fully_resolved = ConfigHandler.replace_bases(
            widget_bases_resolved, available_bases)

        self.parent.run_cfg.update(fully_resolved)
Exemplo n.º 5
0
def test_replace_base() -> None:
    """Check that a single base is merged into the config.

    The base ``cfg.a.b.c`` defines ``p1``, ``p2`` and ``p3``; the config
    itself also defines ``p1``. The value from the config (5) is expected to
    win over the one from the base (1).
    """
    base_params = {"p1": 1, "p2": 2, "p3": 3}
    nested_bases = {"cfg": {"a": {"b": {"c": base_params}}}}
    start_cfg = {"base": "cfg.a.b.c", "p1": 5}

    merged = ConfigHandler.replace_base(start_cfg, nested_bases)

    assert {"p1", "p2", "p3"} == set(merged.keys())
    assert merged["p1"] == 5
    assert merged["p2"] == 2
    assert merged["p3"] == 3
Exemplo n.º 6
0
def test_load_cfg_file(cifar100_cfg_file: Path) -> None:
    """Check the content obtained by loading the ``cifar100`` config file."""
    loaded = ConfigHandler.load_cfg_file(cifar100_cfg_file)

    assert {"name", "path", "splits"} == set(loaded.keys())

    assert loaded["name"] == "cifar100"
    assert loaded["path"] == "/path/to/cifar100"

    # The splits must be a list starting with 75, 15, 10.
    assert isinstance(loaded["splits"], list)
    for position, expected_split in enumerate((75, 15, 10)):
        assert loaded["splits"][position] == expected_split
Exemplo n.º 7
0
def test_load_cfg(base_cfg_dir: Path, complex_run_file: Path) -> None:
    """Check the config obtained by loading ``complex_run_file``.

    The run file is loaded together with the base configs found in
    ``base_cfg_dir``; the test verifies the top-level values and the
    ``dataset``, ``net`` and ``params`` sections of the result.
    """
    cfg = ConfigHandler.load_cfg(complex_run_file, base_cfg_dir)

    # --- top level ---
    assert isinstance(cfg, dict)
    top_level_keys = {"dataset", "run_name", "net", "optimizer", "lr", "params"}
    assert top_level_keys == set(cfg.keys())

    assert cfg["optimizer"] == "adam"
    assert cfg["lr"] == 5e-3
    assert cfg["run_name"] == "test"

    # --- dataset section ---
    dataset = cfg["dataset"]
    assert isinstance(dataset, dict)
    assert {"name", "path", "splits", "classes"} == set(dataset.keys())
    assert dataset["name"] == "cifar10"
    assert dataset["path"] == "/path/to/cifar10"
    assert dataset["splits"] == [70, 20, 10]
    assert dataset["classes"] == [1, 5, 6]

    # --- net section ---
    net = cfg["net"]
    assert isinstance(net, dict)
    net_keys = {"name", "num_layers", "ckpt_path", "freeze", "use_skip"}
    assert net_keys == set(net.keys())
    assert net["name"] == "efficientnet"
    assert net["num_layers"] == 20
    assert net["ckpt_path"] == "/path/to/efficientnet"
    assert net["freeze"] is False
    assert net["use_skip"] is True

    # --- params section ---
    params = cfg["params"]
    assert isinstance(params, dict)
    assert {"use_bn", "use_dropout", "use_augmentation"} == set(params.keys())
    assert params["use_bn"] is True
    assert params["use_dropout"] is False
    assert params["use_augmentation"] is True
Exemplo n.º 8
0
def test_load_cfg_dir(base_cfg_dir: Path) -> None:
    """Check the nested config obtained by loading ``base_cfg_dir``.

    The test walks through the four expected top-level sections (``var``,
    ``dataset``, ``net`` and ``params``) and verifies the keys and values of
    every nested config. ``base`` entries are expected to be still present
    as plain strings in the loaded result.
    """
    cfg = ConfigHandler.load_cfg_dir(base_cfg_dir)

    assert isinstance(cfg, dict)
    assert {"var", "dataset", "net", "params"} == set(cfg.keys())

    # --- var ---
    var = cfg["var"]
    assert isinstance(var, dict)
    assert {"optimizer", "lr"} == set(var.keys())

    assert var["optimizer"] == "adam"
    assert var["lr"] == 1e-3

    # --- dataset ---
    dataset = cfg["dataset"]
    assert isinstance(dataset, dict)
    assert {"cifar", "imagenet"} == set(dataset.keys())

    imagenet = dataset["imagenet"]
    assert isinstance(imagenet, dict)
    assert {"name", "path", "splits"} == set(imagenet.keys())

    assert imagenet["name"] == "imagenet"
    assert imagenet["path"] == "/path/to/imagenet"
    assert isinstance(imagenet["splits"], list)
    for position, expected_split in enumerate((80, 10, 10)):
        assert imagenet["splits"][position] == expected_split

    cifar = dataset["cifar"]
    assert isinstance(cifar, dict)
    assert {"cifar10", "cifar100"} == set(cifar.keys())

    # Each cifar variant is named after its key and has its own splits.
    cifar_splits = {"cifar10": [70, 20, 10], "cifar100": [75, 15, 10]}
    for variant, expected_splits in cifar_splits.items():
        entry = cifar[variant]
        assert isinstance(entry, dict)
        assert {"name", "path", "splits"} == set(entry.keys())

        assert entry["name"] == variant
        assert entry["path"] == "/path/to/" + variant
        assert isinstance(entry["splits"], list)
        for position in range(3):
            assert entry["splits"][position] == expected_splits[position]

    # --- net ---
    net = cfg["net"]
    assert isinstance(net, dict)
    assert {"resnet", "efficientnet"} == set(net.keys())

    efficientnet = net["efficientnet"]
    assert isinstance(efficientnet, dict)
    efficientnet_keys = {"name", "num_layers", "ckpt_path", "use_skip"}
    assert efficientnet_keys == set(efficientnet.keys())

    assert efficientnet["name"] == "efficientnet"
    assert efficientnet["num_layers"] == 20
    assert efficientnet["ckpt_path"] == "/path/to/efficientnet"
    assert efficientnet["use_skip"] is True

    resnet = net["resnet"]
    assert isinstance(resnet, dict)
    assert {"resnet18", "resnet101"} == set(resnet.keys())

    # Each resnet variant is named after its key and refers to efficientnet
    # as its base.
    resnet_layers = {"resnet18": 18, "resnet101": 101}
    for variant, expected_layers in resnet_layers.items():
        entry = resnet[variant]
        assert isinstance(entry, dict)
        assert {"name", "num_layers", "ckpt_path", "base"} == set(entry.keys())

        assert entry["name"] == variant
        assert entry["num_layers"] == expected_layers
        assert entry["ckpt_path"] == "/path/to/" + variant
        assert entry["base"] == "net.efficientnet"

    # --- params ---
    params = cfg["params"]
    assert isinstance(params, dict)
    assert {"default", "test", "train"} == set(params.keys())

    default_params = params["default"]
    assert isinstance(default_params, dict)
    assert {"use_bn", "use_dropout"} == set(default_params.keys())

    assert default_params["use_bn"] is True
    assert default_params["use_dropout"] is False

    # "test" and "train" share the same shape and base; only the
    # augmentation flag differs.
    for mode, expected_augmentation in (("test", False), ("train", True)):
        mode_params = params[mode]
        assert isinstance(mode_params, dict)
        assert {"base", "use_augmentation"} == set(mode_params.keys())

        assert mode_params["base"] == "params.default"
        assert mode_params["use_augmentation"] is expected_augmentation
Exemplo n.º 9
0
def test_load_cfg_file_exception(wrong_run_file: Path) -> None:
    """Check that loading ``wrong_run_file`` fails with a ``ValueError``."""
    expected_error = ValueError
    with pytest.raises(expected_error):
        ConfigHandler.load_cfg_file(wrong_run_file)
Exemplo n.º 10
0
def test_replace_base_exception() -> None:
    """Check that a base missing from the base configs raises ``ValueError``.

    The config refers to ``cfg.a.b.d``, while the base configs only define
    ``cfg.a.b.c``.
    """
    nested_bases = {"cfg": {"a": {"b": {"c": {"p1": 1, "p2": 2, "p3": 3}}}}}
    cfg_with_unknown_base = {"base": "cfg.a.b.d", "p1": 5}

    with pytest.raises(ValueError):
        ConfigHandler.replace_base(cfg_with_unknown_base, nested_bases)