def build_config(config_path: (str, list), log_root: str, log_dir: str, ckpt_path: str) -> [dict, str]:
    """
    Function to initialise the log directory used to store results and to load
    the configuration for a saved model.

    :param config_path: str or list of str, path(s) of configuration files.
        If empty ("" / [] / None), the config.yaml stored in the directory two
        levels above the checkpoint file is loaded instead.
    :param log_root: str, root of logs
    :param log_dir: str, path to store logs.
    :param ckpt_path: str, path where model is stored; must end with .ckpt.
    :return:
        - config, configuration dictionary
        - log_dir, path of the directory for saving outputs
    :raises ValueError: if ckpt_path does not end with .ckpt.
    """
    import os

    # check ckpt_path before touching the log directory
    if not ckpt_path.endswith(".ckpt"):
        raise ValueError(
            "checkpoint path should end with .ckpt, got {}".format(ckpt_path))

    # init log directory
    log_dir = build_log_dir(log_root=log_root, log_dir=log_dir)

    # load config
    if not config_path:
        # use default config, which should be provided in the log folder:
        # the grandparent directory of the .ckpt file. os.path is used (rather
        # than splitting on "/") so short paths such as "save/model.ckpt" do not
        # resolve to the filesystem root "/config.yaml".
        run_dir = os.path.dirname(os.path.dirname(os.path.abspath(ckpt_path)))
        config = config_parser.load_configs(os.path.join(run_dir, "config.yaml"))
    else:
        # use customized config
        logging.warning(
            "Using customized configuration. "
            "The code might break if the config of the model doesn't match the saved model."
        )
        config = config_parser.load_configs(config_path)
    return config, log_dir
def build_config(config_path: (str, list), log_dir: str, ckpt_path: str) -> [dict, str]:
    """
    Function to initialise log directories, assert that checkpointed model is the
    right type and to parse the configuration for training.

    :param config_path: str or list of str, path(s) to config file(s)
    :param log_dir: str, path to where training logs to be stored.
    :param ckpt_path: str, path where model is stored; "" means no checkpoint.
    :return:
        - config: a dictionary saving configuration
        - log_dir: the path of directory to save logs
    :raises ValueError: if ckpt_path is non-empty and does not end with .ckpt.
    """
    # check checkpoint path first, so an invalid path is rejected before
    # build_log_dir initialises any log directory
    if ckpt_path != "" and not ckpt_path.endswith(".ckpt"):
        raise ValueError(
            "checkpoint path should end with .ckpt, got {}".format(ckpt_path))

    # init log directory
    log_dir = build_log_dir(log_dir)

    # load config and back it up into the log directory
    config = config_parser.load_configs(config_path)
    config_parser.save(config=config, out_dir=log_dir)
    return config, log_dir
def test_load_configs():
    """
    test load_configs by checking outputs

    Both a single config file (passed as a str) and a list of split config
    files are expected to load into the same dictionary as the reference
    YAML file.
    """
    # reference config, parsed once and reused for both checks;
    # explicit utf-8 so decoding does not depend on the platform locale
    with open("config/unpaired_labeled_ddf.yaml", encoding="utf-8") as file:
        expected = yaml.load(file, Loader=yaml.FullLoader)

    # single config
    # input is str not list
    got = load_configs("config/unpaired_labeled_ddf.yaml")
    assert got == expected

    # multiple configs, which together should equal the reference config
    got = load_configs(config_path=[
        "config/test/ddf.yaml",
        "config/test/unpaired_nifti.yaml",
        "config/test/labeled.yaml",
    ])
    assert got == expected
def build_config(
    config_path: (str, list),
    log_root: str,
    log_dir: str,
    ckpt_path: str,
    max_epochs: int = -1,
) -> [dict, str]:
    """
    Function to initialise log directories, assert that checkpointed model is the
    right type and to parse the configuration for training.

    :param config_path: str or list of str, path(s) to config file(s)
    :param log_root: str, root of logs
    :param log_dir: str, path to where training logs to be stored.
    :param ckpt_path: str, path where model is stored; "" means no checkpoint.
    :param max_epochs: int, if max_epochs > 0, will use it to overwrite the configuration
    :return:
        - config: a dictionary saving configuration
        - log_dir: the path of directory to save logs
    :raises ValueError: if ckpt_path is non-empty and does not end with .ckpt.
    """
    # check checkpoint path first, so an invalid path is rejected before
    # build_log_dir initialises any log directory
    if ckpt_path != "" and not ckpt_path.endswith(".ckpt"):
        raise ValueError(
            f"checkpoint path should end with .ckpt, got {ckpt_path}")

    # init log directory
    log_dir = build_log_dir(log_root=log_root, log_dir=log_dir)

    # load config
    config = config_parser.load_configs(config_path)

    # overwrite epochs and save_period if necessary
    if max_epochs > 0:
        train_config = config["train"]
        train_config["epochs"] = max_epochs
        # clamp save_period so it never exceeds the overridden epoch count
        train_config["save_period"] = min(max_epochs, train_config["save_period"])

    # backup config; done before the GPU scaling below so the saved copy keeps
    # the per-GPU batch_size
    config_parser.save(config=config, out_dir=log_dir)

    # batch_size in original config corresponds to batch_size per GPU;
    # scale it by the number of visible GPUs (at least 1 for CPU-only runs)
    gpus = tf.config.experimental.list_physical_devices("GPU")
    config["train"]["preprocess"]["batch_size"] *= max(len(gpus), 1)

    return config, log_dir