Exemplo n.º 1
0
def test_grouped_mr_heart(old_config_path: str, latest_config_path: str):
    """
    Check that parse_v011 converts the old config into the latest config.

    :param old_config_path: path of the YAML file holding the old config.
    :param latest_config_path: path of the YAML file holding the config
        the conversion is expected to produce.
    """

    def read_yaml(path: str):
        # parse a single YAML file with the full loader
        with open(path) as stream:
            return yaml.load(stream, Loader=yaml.FullLoader)

    legacy = read_yaml(old_config_path)
    expected = read_yaml(latest_config_path)

    converted = parse_v011(old_config=legacy)
    assert converted == expected
Exemplo n.º 2
0
def config_sanity_check(config: dict) -> dict:
    """
    Check if the given config satisfies the requirements.

    The dataset section is validated first, then the config is passed
    through parse_v011 for back compatibility, and finally the model and
    loss settings of the resulting config are checked.

    :param config: entire config.
    :return: the config returned by parse_v011.
    :raises ValueError: if the data type or data format is not supported,
        if a data directory is neither a string, a list nor None,
        or if a conditional model is configured with unlabeled data.
    """

    # check data
    data_config = config["dataset"]

    data_type = data_config["type"]
    if data_type not in ["paired", "unpaired", "grouped"]:
        # report the offending value, not the builtin `type`
        raise ValueError(
            f"data type must be paired / unpaired / grouped, got {data_type}.")

    data_format = data_config["format"]
    if data_format not in ["nifti", "h5"]:
        # report the offending value, not the builtin `format`
        raise ValueError(
            f"data format must be nifti / h5, got {data_format}.")

    assert "dir" in data_config
    for mode in ["train", "valid", "test"]:
        assert mode in data_config["dir"].keys()
        data_dir = data_config["dir"][mode]
        if data_dir is None:
            # a missing directory is tolerated, only reported
            logging.warning(f"Data directory for {mode} is not defined.")
        if not (isinstance(data_dir, (str, list)) or data_dir is None):
            raise ValueError(
                f"data_dir for mode {mode} must be string or list of strings, "
                f"got {data_dir}.")

    # back compatibility support
    config = parse_v011(config)

    # check model
    # NOTE(review): data_config was read before parse_v011 ran; confirm it
    # still reflects the dataset section of the returned config.
    if config["train"]["method"] == "conditional":
        if data_config["labeled"] is False:  # unlabeled
            raise ValueError(
                "For conditional model, data have to be labeled, got unlabeled data."
            )

    # loss weights are expected to be positive; other values only warn
    for name in ["image", "label", "regularization"]:
        loss_config = config["train"]["loss"][name]
        if not isinstance(loss_config, list):
            # a single loss is handled like a one-element list
            loss_config = [loss_config]

        for loss_i in loss_config:
            loss_weight = loss_i["weight"]
            if loss_weight <= 0:
                logging.warning("The %s loss weight %.2f is not positive.",
                                name, loss_weight)

    return config
Exemplo n.º 3
0
def config_sanity_check(config: dict) -> dict:
    """
    Validate the model settings of a config.

    The config is first passed through parse_v011 (back compatibility
    support); the checks run on the config that call returns.

    :param config: entire config.
    :return: the config returned by parse_v011.
    :raises ValueError: if the training method is conditional while the
        training data is marked as unlabeled.
    """

    # back compatibility support
    upgraded = parse_v011(config)

    # only the conditional model puts a requirement on the labels
    if upgraded["train"]["method"] != "conditional":
        return upgraded

    # anything other than an explicit False passes the check
    if upgraded["dataset"]["train"]["labeled"] is not False:
        return upgraded

    raise ValueError(
        "For conditional model, data have to be labeled, got unlabeled data."
    )