def build_config(config_path: (str, list), log_dir: str, ckpt_path: str) -> [dict, str]:
    """
    Create a new log directory and load the configuration for prediction.

    :param config_path: string or list of strings, path of configuration files;
        if empty, the config saved two levels above the checkpoint is used.
    :param log_dir: string, path to store logs.
    :param ckpt_path: str, path where model is stored, must end with ".ckpt".
    :raises ValueError: if ckpt_path does not end with ".ckpt".
    :return:
        - config, configuration dictionary
        - log_dir, path of the directory for saving outputs
    """
    # check ckpt_path
    if not ckpt_path.endswith(".ckpt"):
        raise ValueError(
            "checkpoint path should end with .ckpt, got {}".format(ckpt_path))

    # init log directory
    log_dir = build_log_dir(log_dir)

    # load config
    if config_path == "":
        # use default config, which should be provided in the log folder
        # (checkpoints live in <log>/save/<name>.ckpt, config.yaml sits in <log>)
        config = config_parser.load_configs(
            "/".join(ckpt_path.split("/")[:-2]) + "/config.yaml")
    else:
        # use customized config
        # fix: the two concatenated literals previously had no separating space
        logging.warning(
            "Using customized configuration. "
            "The code might break if the config of the model doesn't match the saved model."
        )
        config = config_parser.load_configs(config_path)
    return config, log_dir
def init(log_dir, ckpt_path, config_path):
    """
    Create a new log directory and load the configuration for prediction.

    :param log_dir: string, path to store logs; if empty, a timestamped
        directory under "logs" is created.
    :param ckpt_path: str, path where model is stored, must end with ".ckpt".
    :param config_path: string or list of strings, path of configuration files;
        if empty, the config saved two levels above the checkpoint is used.
    :raises ValueError: if ckpt_path does not end with ".ckpt".
    :return:
        - config, configuration dictionary
        - log_dir, path of the directory for saving outputs
    """
    # check ckpt_path
    if not ckpt_path.endswith(".ckpt"):
        raise ValueError(
            "checkpoint path should end with .ckpt, got {}".format(ckpt_path))

    # init log directory, defaulting to a timestamped folder under "logs"
    log_dir = os.path.join(
        "logs",
        datetime.now().strftime("%Y%m%d-%H%M%S") if log_dir == "" else log_dir)
    if os.path.exists(log_dir):
        logging.warning("Log directory {} exists already.".format(log_dir))
    else:
        os.makedirs(log_dir)

    # load config
    if config_path == "":
        # use default config, which should be provided in the log folder
        config = config_parser.load_configs(
            "/".join(ckpt_path.split("/")[:-2]) + "/config.yaml")
    else:
        # use customized config
        # fix: the two concatenated literals previously had no separating space
        logging.warning(
            "Using customized configuration. "
            "The code might break if the config of the model doesn't match the saved model."
        )
        config = config_parser.load_configs(config_path)
    return config, log_dir
def build_config(config_path: (str, list), log_root: str, log_dir: str, ckpt_path: str) -> [dict, str]:
    """
    Create a new log directory and load the configuration for prediction.

    :param config_path: string or list of strings, path of configuration files;
        if empty, the config saved two levels above the checkpoint is used.
    :param log_root: str, root of logs
    :param log_dir: string, path to store logs.
    :param ckpt_path: str, path where model is stored; "~" is expanded.
    :return:
        - config, configuration dictionary
        - log_dir, path of the directory for saving outputs
        - ckpt_path, checkpoint path with the user home expanded
    """
    # init log directory
    log_dir = build_log_dir(log_root=log_root, log_dir=log_dir)

    # replace the ~ with user home path
    ckpt_path = os.path.expanduser(ckpt_path)

    # load config
    if config_path == "":
        # use default config, which should be provided in the log folder
        config = config_parser.load_configs(
            "/".join(ckpt_path.split("/")[:-2]) + "/config.yaml")
    else:
        # use customized config
        # fix: the two concatenated literals previously had no separating space
        logging.warning(
            "Using customized configuration. "
            "The code might break if the config doesn't match the saved model."
        )
        config = config_parser.load_configs(config_path)
    return config, log_dir, ckpt_path
def build_config(config_path: Union[str, List[str]], log_dir: str, exp_name: str, ckpt_path: str) -> Tuple[Dict, str, str]:
    """
    Create a new log directory and load the configuration for prediction.

    :param config_path: path of configuration files; if empty, the config
        saved two levels above the checkpoint is used.
    :param log_dir: path of the log directory.
    :param exp_name: experiment name.
    :param ckpt_path: path where model is stored; "~" is expanded.
    :return:
        - config, configuration dictionary.
        - log_dir, path of the directory for saving outputs.
        - ckpt_path, checkpoint path with the user home expanded.
    """
    # init log directory
    log_dir = build_log_dir(log_dir=log_dir, exp_name=exp_name)

    # replace the ~ with user home path
    ckpt_path = os.path.expanduser(ckpt_path)

    # load config
    if config_path == "":
        # use default config, which should be provided in the log folder
        config = config_parser.load_configs(
            "/".join(ckpt_path.split("/")[:-2]) + "/config.yaml")
    else:
        # use customized config
        # fix: the two concatenated literals previously had no separating space
        logging.warning(
            "Using customized configuration. "
            "The code might break if the config doesn't match the saved model."
        )
        config = config_parser.load_configs(config_path)
    return config, log_dir, ckpt_path
def test_multi_configs(self):
    """Loading several partial configs should equal the single full config."""
    partial_paths = [
        "deepreg/config/test/ddf.yaml",
        "deepreg/config/test/unpaired_nifti.yaml",
        "deepreg/config/test/labeled.yaml",
    ]
    got = load_configs(config_path=partial_paths)
    expected = load_configs(
        config_path="deepreg/config/unpaired_labeled_ddf.yaml")
    self.assertEqual(got, expected)
def build_config(config_path: (str, list), log_dir: str, ckpt_path: str) -> [dict, str]:
    """
    Initialise the log directory, validate the checkpoint path and parse the
    configuration for training.

    :param config_path: list of str, path to config file
    :param log_dir: str, path to where training logs to be stored.
    :param ckpt_path: str, path where model is stored; empty means training
        from scratch, otherwise it must end with ".ckpt".
    :raises ValueError: if a non-empty ckpt_path does not end with ".ckpt".
    :return:
        - config: a dictionary saving configuration
        - log_dir: the path of directory to save logs
    """
    # init log directory
    log_dir = build_log_dir(log_dir)

    # check checkpoint path; an empty path means no checkpoint is loaded
    if ckpt_path != "":
        if not ckpt_path.endswith(".ckpt"):
            # include the offending path, consistent with the predict-side check
            raise ValueError(
                "checkpoint path should end with .ckpt, got {}".format(ckpt_path))

    # load and backup config
    config = config_parser.load_configs(config_path)
    config_parser.save(config=config, out_dir=log_dir)

    return config, log_dir
def init(config_path, log_dir, ckpt_path):
    """
    Initialise the log directory, validate the checkpoint path and parse the
    configuration for training.

    :param config_path: list of str, path to config file
    :param log_dir: str, path to where training logs to be stored; if empty,
        a timestamped directory under "logs" is created.
    :param ckpt_path: str, path where model is stored; empty means training
        from scratch, otherwise it must end with ".ckpt".
    :raises ValueError: if a non-empty ckpt_path does not end with ".ckpt".
    :return:
        - config: a dictionary saving configuration
        - log_dir: the path of directory to save logs
    """
    # init log directory, defaulting to a timestamped folder under "logs"
    log_dir = os.path.join(
        "logs",
        datetime.now().strftime("%Y%m%d-%H%M%S") if log_dir == "" else log_dir)
    if os.path.exists(log_dir):
        logging.warning("Log directory {} exists already.".format(log_dir))
    else:
        os.makedirs(log_dir)

    # check checkpoint path; an empty path means no checkpoint is loaded
    if ckpt_path != "":
        if not ckpt_path.endswith(".ckpt"):
            # include the offending path, consistent with the predict-side check
            raise ValueError(
                "checkpoint path should end with .ckpt, got {}".format(ckpt_path))

    # load and backup config
    config = config_parser.load_configs(config_path)
    config_parser.save(config=config, out_dir=log_dir)

    return config, log_dir
def test_outdated_config(self):
    """
    Loading an outdated config should upgrade it to match the current demo
    config and write an updated copy alongside the original.
    """
    with open("demos/grouped_mr_heart/grouped_mr_heart.yaml") as file:
        expected = yaml.load(file, Loader=yaml.FullLoader)

    updated_file_path = "config/test/updated_grouped_mr_heart_v011.yaml"
    try:
        got = load_configs("config/test/grouped_mr_heart_v011.yaml")
        assert got == expected
        # loading an outdated config saves an upgraded copy next to it
        assert os.path.isfile(updated_file_path)
    finally:
        # fix: previously the file was only removed when every assert passed,
        # leaking the generated file into subsequent test runs
        if os.path.isfile(updated_file_path):
            os.remove(updated_file_path)
def test_multiple_configs(self):
    """Merging partial configs should reproduce the full reference config."""
    partial_paths = [
        "config/test/ddf.yaml",
        "config/test/unpaired_nifti.yaml",
        "config/test/labeled.yaml",
    ]
    got = load_configs(config_path=partial_paths)
    with open("config/unpaired_labeled_ddf.yaml") as file:
        expected = yaml.load(file, Loader=yaml.FullLoader)
    assert got == expected
def test_load_configs():
    """
    Test load_configs by checking outputs against the reference yaml file.
    """
    # fix: the reference yaml was previously opened and parsed twice;
    # load it once and reuse it for both cases
    with open("deepreg/config/unpaired_labeled_ddf.yaml") as file:
        expected = yaml.load(file, Loader=yaml.FullLoader)

    # single config, input is str not list
    got = load_configs("deepreg/config/unpaired_labeled_ddf.yaml")
    assert got == expected

    # multiple configs merged together should equal the single full config
    got = load_configs(config_path=[
        "deepreg/config/test/ddf.yaml",
        "deepreg/config/test/unpaired_nifti.yaml",
        "deepreg/config/test/labeled.yaml",
    ])
    assert got == expected
def build_config(
    config_path: (str, list),
    log_root: str,
    log_dir: str,
    ckpt_path: str,
    max_epochs: int = -1,
) -> [dict, str]:
    """
    Initialise the log directory and parse the configuration for training.

    :param config_path: list of str, path to config file
    :param log_root: root of logs
    :param log_dir: path to where training logs to be stored.
    :param ckpt_path: path where model is stored; "~" is expanded.
    :param max_epochs: if max_epochs > 0, use it to overwrite the configuration
    :return:
        - config: a dictionary saving configuration
        - log_dir: the path of directory to save logs
    """
    log_dir = build_log_dir(log_root=log_root, log_dir=log_dir)
    config = config_parser.load_configs(config_path)

    # expand "~" so the checkpoint may be given relative to the home directory
    ckpt_path = os.path.expanduser(ckpt_path)

    # a positive max_epochs overrides the configured epochs; save_period must
    # never exceed the number of epochs actually run
    if max_epochs > 0:
        train_config = config["train"]
        train_config["epochs"] = max_epochs
        train_config["save_period"] = min(max_epochs,
                                          train_config["save_period"])

    # keep a copy of the effective config next to the logs
    config_parser.save(config=config, out_dir=log_dir)

    # batch_size in original config corresponds to batch_size per GPU
    num_devices = max(len(tf.config.experimental.list_physical_devices("GPU")), 1)
    config["train"]["preprocess"]["batch_size"] *= num_devices

    return config, log_dir, ckpt_path
def build_config(
    config_path: Union[str, List[str]],
    log_dir: str,
    exp_name: str,
    ckpt_path: str,
    max_epochs: int = -1,
) -> Tuple[Dict, str, str]:
    """
    Initialise the log directory and parse the configuration for training.

    :param config_path: list of str, path to config file
    :param log_dir: path of the log directory
    :param exp_name: name of the experiment
    :param ckpt_path: path where model is stored; "~" is expanded.
    :param max_epochs: if max_epochs > 0, use it to overwrite the configuration
    :return:
        - config: a dictionary saving configuration
        - log_dir: the path of directory to save logs
    """
    log_dir = build_log_dir(log_dir=log_dir, exp_name=exp_name)
    config = config_parser.load_configs(config_path)

    # expand "~" so the checkpoint may be given relative to the home directory
    ckpt_path = os.path.expanduser(ckpt_path)

    # a positive max_epochs overrides the configured epochs; save_period must
    # never exceed the number of epochs actually run
    if max_epochs > 0:
        train_config = config["train"]
        train_config["epochs"] = max_epochs
        train_config["save_period"] = min(max_epochs,
                                          train_config["save_period"])

    # keep a copy of the effective config next to the logs
    config_parser.save(config=config, out_dir=log_dir)

    return config, log_dir, ckpt_path
def main(args=None):
    """
    Entry point for gen_tfrecord: convert each dataset split into tfrecords.

    :param args: optional list of command line arguments; None means sys.argv.
    """
    parser = argparse.ArgumentParser(description="gen_tfrecord")

    # command line options (both are flags, -c is required)
    parser.add_argument("--config_path",
                        "-c",
                        help="Path of config",
                        type=str,
                        required=True)
    parser.add_argument(
        "--examples_per_tfrecord",
        "-n",
        help="Number of examples per tfrecord",
        type=int,
        default=64,
    )
    parsed = parser.parse_args(args)

    config = config_parser.load_configs(parsed.config_path)
    dataset_config = config["dataset"]
    tfrecord_dir = dataset_config["tfrecord_dir"]
    # blank the dir in the config so the loaders read raw data, not tfrecords
    dataset_config["tfrecord_dir"] = ""

    # optionally wipe a pre-existing output directory after confirmation
    if os.path.isdir(tfrecord_dir):
        remove = input("%s exists. Remove it or not? Y/N\n" % tfrecord_dir)
        if remove.lower() == "y":
            shutil.rmtree(tfrecord_dir)

    # write one tfrecord folder per dataset split
    for mode in ["train", "valid", "test"]:
        data_loader = load.get_data_loader(dataset_config, mode)
        write_tfrecords(
            data_dir=os.path.join(tfrecord_dir, mode),
            data_generator=data_loader.data_generator(),
            examples_per_tfrecord=parsed.examples_per_tfrecord,
        )
def test_single_config(self):
    """A single config path should load to the reference yaml content."""
    config_file = "config/unpaired_labeled_ddf.yaml"
    got = load_configs(config_file)
    with open(config_file) as file:
        expected = yaml.load(file, Loader=yaml.FullLoader)
    assert got == expected