示例#1
0
def build_config(config_path: (str, list), log_root: str, log_dir: str,
                 ckpt_path: str) -> [dict, str]:
    """
    Validate the checkpoint path, create the log directory and load the
    configuration used together with a saved model.

    :param config_path: str or list of str, path(s) of configuration files.
        If empty ("", [] or None), the config.yaml stored two directory
        levels above the checkpoint file is loaded instead.
    :param log_root: str, root of logs
    :param log_dir: str, path to store logs.
    :param ckpt_path: str, path where model is stored; must end with .ckpt.
    :return: - config, configuration dictionary
             - log_dir, path of the directory for saving outputs
    :raises ValueError: if ckpt_path does not end with ".ckpt".
    """
    import os  # local import: file-level imports are not visible here

    # check ckpt_path before touching the filesystem
    if not ckpt_path.endswith(".ckpt"):
        raise ValueError(
            "checkpoint path should end with .ckpt, got {}".format(ckpt_path))

    # init log directory
    log_dir = build_log_dir(log_root=log_root, log_dir=log_dir)

    # load config
    if not config_path:
        # use default config, which should be provided in the log folder two
        # levels above the checkpoint (<log>/<subdir>/model.ckpt ->
        # <log>/config.yaml). os.path keeps a shallow relative checkpoint
        # path relative instead of resolving to "/config.yaml" at the root.
        default_config_path = os.path.join(
            os.path.dirname(os.path.dirname(ckpt_path)), "config.yaml")
        config = config_parser.load_configs(default_config_path)
    else:
        # use customized config
        logging.warning(
            "Using customized configuration. "
            "The code might break if the config of the model doesn't match the saved model."
        )
        config = config_parser.load_configs(config_path)
    return config, log_dir
示例#2
0
def build_config(config_path: (str, list), log_dir: str,
                 ckpt_path: str) -> [dict, str]:
    """
    Check the checkpoint path, initialise the log directory and parse the
    configuration for training, saving a backup copy of it into the log
    directory.

    :param config_path: str or list of str, path(s) to config file(s)
    :param log_dir: str, path to where training logs to be stored.
    :param ckpt_path: str, path where model is stored; may be empty, in which
        case it is not checked.
    :return: - config: a dictionary saving configuration
             - log_dir: the path of directory to save logs
    :raises ValueError: if ckpt_path is non-empty and does not end with ".ckpt".
    """
    # check checkpoint path first, so that an invalid path fails before
    # build_log_dir has created anything on disk
    if ckpt_path and not ckpt_path.endswith(".ckpt"):
        raise ValueError(
            f"checkpoint path should end with .ckpt, got {ckpt_path}")

    # init log directory
    log_dir = build_log_dir(log_dir)

    # load and backup config
    config = config_parser.load_configs(config_path)
    config_parser.save(config=config, out_dir=log_dir)
    return config, log_dir
示例#3
0
def test_load_configs():
    """
    Check that load_configs produces the same dictionary as parsing the
    reference YAML file directly, both when given a single path as a plain
    str and when given a list of partial config files.
    """
    reference_path = "config/unpaired_labeled_ddf.yaml"
    with open(reference_path) as yaml_file:
        reference = yaml.load(yaml_file, Loader=yaml.FullLoader)

    # a single config passed as str rather than list
    single = load_configs(reference_path)
    assert single == reference

    # several partial configs which together form the reference config
    partial_paths = [
        "config/test/ddf.yaml",
        "config/test/unpaired_nifti.yaml",
        "config/test/labeled.yaml",
    ]
    merged = load_configs(config_path=partial_paths)
    assert merged == reference
def build_config(
    config_path: (str, list),
    log_root: str,
    log_dir: str,
    ckpt_path: str,
    max_epochs: int = -1,
) -> [dict, str]:
    """
    Check the checkpoint path, initialise the log directory and parse the
    configuration for training.

    The parsed configuration is backed up into the log directory before the
    batch size is scaled by the number of visible GPUs, so the saved copy
    keeps the per-GPU batch size while the returned dict holds the total.

    :param config_path: str or list of str, path(s) to config file(s)
    :param log_root: str, root of logs
    :param log_dir: str, path to where training logs to be stored.
    :param ckpt_path: str, path where model is stored; may be empty, in which
        case it is not checked.
    :param max_epochs: int, if max_epochs > 0, will use it to overwrite
        config["train"]["epochs"] and cap config["train"]["save_period"]
    :return: - config: a dictionary saving configuration, with
               config["train"]["preprocess"]["batch_size"] multiplied by
               max(number of GPUs, 1)
             - log_dir: the path of directory to save logs
    :raises ValueError: if ckpt_path is non-empty and does not end with ".ckpt".
    """
    # check checkpoint path before build_log_dir so that an invalid path
    # does not leave behind a freshly created log directory
    if ckpt_path and not ckpt_path.endswith(".ckpt"):
        raise ValueError(
            f"checkpoint path should end with .ckpt, got {ckpt_path}")

    # init log directory
    log_dir = build_log_dir(log_root=log_root, log_dir=log_dir)

    # load config
    config = config_parser.load_configs(config_path)

    # overwrite epochs and save_period if necessary
    if max_epochs > 0:
        train_config = config["train"]
        train_config["epochs"] = max_epochs
        # keep save_period <= epochs (presumably so a save still happens
        # within the shortened run)
        train_config["save_period"] = min(max_epochs,
                                          train_config["save_period"])

    # backup config (still with the per-GPU batch size)
    config_parser.save(config=config, out_dir=log_dir)

    # batch_size in original config corresponds to batch_size per GPU
    gpus = tf.config.experimental.list_physical_devices("GPU")
    config["train"]["preprocess"]["batch_size"] *= max(len(gpus), 1)

    return config, log_dir