Beispiel #1
0
def build_config(config_path: (str, list), log_dir: str,
                 ckpt_path: str) -> [dict, str]:
    """
    Function to initialise log directories,
    assert that checkpointed model is the right
    type and to parse the configuration for training.

    The checkpoint path is validated before anything is written to disk,
    so an invalid path does not leave behind a freshly created log directory.

    :param config_path: str or list of str, path(s) to config file(s)
    :param log_dir: str, path to where training logs to be stored.
    :param ckpt_path: str, path where model is stored; an empty string
        (or None) means no checkpoint is given and the check is skipped.
    :return: - config: a dictionary saving configuration
             - log_dir: the path of directory to save logs
    :raises ValueError: if ckpt_path is given but does not end with ".ckpt"
    """

    # check checkpoint path first: fail fast before creating any directory
    if ckpt_path and not ckpt_path.endswith(".ckpt"):
        raise ValueError(
            f"checkpoint path should end with .ckpt, got {ckpt_path}")

    # init log directory
    log_dir = build_log_dir(log_dir)

    # load config and keep a backup copy in the log directory
    config = config_parser.load_configs(config_path)
    config_parser.save(config=config, out_dir=log_dir)
    return config, log_dir
def build_config(
    config_path: (str, list),
    log_root: str,
    log_dir: str,
    ckpt_path: str,
    max_epochs: int = -1,
) -> [dict, str]:
    """
    Function to initialise log directories,
    assert that checkpointed model is the right
    type and to parse the configuration for training.

    The checkpoint path is validated before anything is written to disk,
    so an invalid path does not leave behind a freshly created log directory.

    :param config_path: str or list of str, path(s) to config file(s)
    :param log_root: str, root of logs
    :param log_dir: str, path to where training logs to be stored.
    :param ckpt_path: str, path where model is stored; an empty string
        (or None) means no checkpoint is given and the check is skipped.
    :param max_epochs: int, if max_epochs > 0, will use it to overwrite
        config["train"]["epochs"] and cap config["train"]["save_period"]
    :return: - config: a dictionary saving configuration; its
               train/preprocess batch_size is scaled by the number of
               visible GPUs (at least 1)
             - log_dir: the path of directory to save logs
    :raises ValueError: if ckpt_path is given but does not end with ".ckpt"
    """

    # check checkpoint path first: fail fast before creating any directory
    if ckpt_path and not ckpt_path.endswith(".ckpt"):
        raise ValueError(
            f"checkpoint path should end with .ckpt, got {ckpt_path}")

    # init log directory
    log_dir = build_log_dir(log_root=log_root, log_dir=log_dir)

    # load config
    config = config_parser.load_configs(config_path)

    # overwrite epochs and save_period if necessary
    if max_epochs > 0:
        config["train"]["epochs"] = max_epochs
        # a save period longer than the run would never trigger a save
        config["train"]["save_period"] = min(max_epochs,
                                             config["train"]["save_period"])

    # backup config; done before the GPU scaling below so the saved copy
    # keeps the per-GPU batch_size
    config_parser.save(config=config, out_dir=log_dir)

    # batch_size in original config corresponds to batch_size per GPU
    gpus = tf.config.experimental.list_physical_devices("GPU")
    config["train"]["preprocess"]["batch_size"] *= max(len(gpus), 1)

    return config, log_dir
Beispiel #3
0
def test_save():
    """Test that save writes the expected yaml file and rejects non-yaml names."""
    # each case: (extra keyword arguments for save, file expected to appear)
    written_cases = [
        ({}, "config.yaml"),  # default file name
        ({"filename": "test.yaml"}, "test.yaml"),  # custom file name
    ]
    for extra_kwargs, expected_file in written_cases:
        with TempDirectory() as tempdir:
            save(config={"x": 1}, out_dir=tempdir.path, **extra_kwargs)
            expected_path = os.path.join(tempdir.path, expected_file)
            assert os.path.exists(expected_path)

    # a filename without the .yaml extension must be refused
    with TempDirectory() as tempdir, pytest.raises(AssertionError):
        save(config={"x": 1}, out_dir=tempdir.path, filename="test.txt")