コード例 #1
0
def build_config(config_path: (str, list), log_dir: str,
                 ckpt_path: str) -> [dict, str]:
    """
    Validate the checkpoint path, create the log directory and load the
    configuration used for prediction.

    :param config_path: string or list of strings, path of configuration files;
        an empty string means "use the config.yaml saved next to the
        checkpoint's parent directory" (see below).
    :param log_dir: string, path to store logs, passed to build_log_dir.
    :param ckpt_path: str, path where model is stored; must end with ".ckpt".
    :return: - config, configuration dictionary
             - log_dir, path of the directory for saving outputs
    :raises ValueError: if ckpt_path does not end with ".ckpt".
    """
    import os  # only needed for portable path handling below

    # check ckpt_path before creating any directory
    if not ckpt_path.endswith(".ckpt"):
        raise ValueError(
            "checkpoint path should end with .ckpt, got {}".format(ckpt_path))

    # init log directory
    log_dir = build_log_dir(log_dir)

    # load config
    if config_path == "":
        # use default config, which should be provided in the log folder.
        # The code assumes the checkpoint lives one directory below that
        # folder, so config.yaml is two levels above the checkpoint file.
        # os.path is used (instead of splitting on "/") so that short paths
        # such as "save/w.ckpt" resolve to "config.yaml" rather than the
        # filesystem root "/config.yaml", and native separators work.
        default_config_path = os.path.join(
            os.path.dirname(os.path.dirname(ckpt_path)), "config.yaml")
        config = config_parser.load_configs(default_config_path)
    else:
        # use customized config
        logging.warning(
            "Using customized configuration. "
            "The code might break if the config of the model doesn't match the saved model."
        )
        config = config_parser.load_configs(config_path)
    return config, log_dir
コード例 #2
0
ファイル: predict.py プロジェクト: knvsmadhav/DeepReg
def init(log_dir, ckpt_path, config_path):
    """
    Validate the checkpoint path, create the log directory and load the
    configuration used for prediction.

    :param log_dir: string, name of the directory under "logs" used to store
        results; when empty, a "%Y%m%d-%H%M%S" timestamp is used instead.
    :param ckpt_path: str, path where model is stored; must end with ".ckpt".
    :param config_path: string or list of strings, path of configuration
        files; an empty string means "use the config.yaml saved next to the
        checkpoint's parent directory" (see below).
    :return: - config, configuration dictionary
             - log_dir, path of the directory for saving outputs
    :raises ValueError: if ckpt_path does not end with ".ckpt".
    """
    # check ckpt_path before creating any directory
    if not ckpt_path.endswith(".ckpt"):
        raise ValueError(
            "checkpoint path should end with .ckpt, got {}".format(ckpt_path))

    # init log directory
    log_dir = os.path.join(
        "logs",
        datetime.now().strftime("%Y%m%d-%H%M%S") if log_dir == "" else log_dir)
    if os.path.exists(log_dir):
        logging.warning("Log directory {} exists already.".format(log_dir))
    else:
        # exist_ok: do not crash if the directory appears between the check
        # above and this call (e.g. two runs started in the same second)
        os.makedirs(log_dir, exist_ok=True)

    # load config
    if config_path == "":
        # use default config, which should be provided in the log folder.
        # The code assumes the checkpoint lives one directory below that
        # folder, so config.yaml is two levels above the checkpoint file.
        # os.path avoids collapsing short paths to the root "/config.yaml".
        default_config_path = os.path.join(
            os.path.dirname(os.path.dirname(ckpt_path)), "config.yaml")
        config = config_parser.load_configs(default_config_path)
    else:
        # use customized config
        logging.warning(
            "Using customized configuration. "
            "The code might break if the config of the model doesn't match the saved model."
        )
        config = config_parser.load_configs(config_path)
    return config, log_dir
コード例 #3
0
def build_config(config_path: (str, list), log_root: str, log_dir: str,
                 ckpt_path: str) -> [dict, str, str]:
    """
    Create the log directory and load the configuration used for prediction.

    :param config_path: string or list of strings, path of configuration files;
        an empty string means "use the config.yaml saved next to the
        checkpoint's parent directory" (see below).
    :param log_root: str, root of logs
    :param log_dir: string, path to store logs.
    :param ckpt_path: str, path where model is stored; "~" is expanded.
    :return: - config, configuration dictionary
             - log_dir, path of the directory for saving outputs
             - ckpt_path, the checkpoint path with "~" expanded
    """

    # init log directory
    log_dir = build_log_dir(log_root=log_root, log_dir=log_dir)

    # replace the ~ with user home path
    ckpt_path = os.path.expanduser(ckpt_path)

    # load config
    if config_path == "":
        # use default config, which should be provided in the log folder.
        # The code assumes the checkpoint lives one directory below that
        # folder, so config.yaml is two levels above the checkpoint file.
        # os.path avoids collapsing short paths to the root "/config.yaml".
        default_config_path = os.path.join(
            os.path.dirname(os.path.dirname(ckpt_path)), "config.yaml")
        config = config_parser.load_configs(default_config_path)
    else:
        # use customized config
        logging.warning(
            "Using customized configuration. "
            "The code might break if the config doesn't match the saved model."
        )
        config = config_parser.load_configs(config_path)
    return config, log_dir, ckpt_path
コード例 #4
0
def build_config(config_path: Union[str, List[str]], log_dir: str,
                 exp_name: str, ckpt_path: str) -> Tuple[Dict, str, str]:
    """
    Create the experiment log directory and load the configuration used for
    prediction.

    :param config_path: path of configuration files; an empty string means
        "use the config.yaml saved next to the checkpoint's parent
        directory" (see below).
    :param log_dir: path of the log directory.
    :param exp_name: experiment name.
    :param ckpt_path: path where model is stored; "~" is expanded.
    :return: - config, configuration dictionary.
             - log_dir, path of the directory for saving outputs.
             - ckpt_path, the checkpoint path with "~" expanded.
    """

    # init log directory
    log_dir = build_log_dir(log_dir=log_dir, exp_name=exp_name)

    # replace the ~ with user home path
    ckpt_path = os.path.expanduser(ckpt_path)

    # load config
    if config_path == "":
        # use default config, which should be provided in the log folder.
        # The code assumes the checkpoint lives one directory below that
        # folder, so config.yaml is two levels above the checkpoint file.
        # os.path avoids collapsing short paths to the root "/config.yaml".
        default_config_path = os.path.join(
            os.path.dirname(os.path.dirname(ckpt_path)), "config.yaml")
        config = config_parser.load_configs(default_config_path)
    else:
        # use customized config
        logging.warning(
            "Using customized configuration. "
            "The code might break if the config doesn't match the saved model."
        )
        config = config_parser.load_configs(config_path)
    return config, log_dir, ckpt_path
コード例 #5
0
 def test_multi_configs(self):
     """Loading the three partial test configs together must give the same
     dict as loading deepreg/config/unpaired_labeled_ddf.yaml on its own."""
     reference = load_configs(
         config_path="deepreg/config/unpaired_labeled_ddf.yaml")
     partial_paths = [
         "deepreg/config/test/ddf.yaml",
         "deepreg/config/test/unpaired_nifti.yaml",
         "deepreg/config/test/labeled.yaml",
     ]
     combined = load_configs(config_path=partial_paths)
     self.assertEqual(combined, reference)
コード例 #6
0
ファイル: train.py プロジェクト: shannonxtreme/DeepReg
def build_config(config_path: (str, list), log_dir: str, ckpt_path: str) -> [dict, str]:
    """
    Validate the checkpoint path, initialise the log directory and parse the
    configuration for training.

    :param config_path: str or list of str, path(s) to config file(s)
    :param log_dir: str, path to where training logs to be stored.
    :param ckpt_path: str, path where model is stored; may be empty,
        otherwise it must end with ".ckpt".
    :return: - config: a dictionary saving configuration
             - log_dir: the path of directory to save logs
    :raises ValueError: if a non-empty ckpt_path does not end with ".ckpt".
    """
    # check checkpoint path first, so an invalid argument does not leave a
    # freshly created log directory behind
    if ckpt_path != "" and not ckpt_path.endswith(".ckpt"):
        raise ValueError(
            "checkpoint path should end with .ckpt, got {}".format(ckpt_path))

    # init log directory
    log_dir = build_log_dir(log_dir)

    # load and backup config (a copy is written to the log directory)
    config = config_parser.load_configs(config_path)
    config_parser.save(config=config, out_dir=log_dir)
    return config, log_dir
コード例 #7
0
ファイル: train.py プロジェクト: acasamitjana/DeepReg
def init(config_path, log_dir, ckpt_path):
    """
    Validate the checkpoint path, initialise the log directory and parse the
    configuration for training.

    :param config_path: list of str, path to config file
    :param log_dir: str, name of the directory under "logs" where training
                    logs are stored; when empty, a "%Y%m%d-%H%M%S"
                    timestamp is used instead.
    :param ckpt_path: str, path where model is stored; may be empty,
                      otherwise it must end with ".ckpt".
    :return: - config: a dictionary saving configuration
             - log_dir: the path of directory to save logs
    :raises ValueError: if a non-empty ckpt_path does not end with ".ckpt".
    """
    # check checkpoint path before touching the filesystem, so an invalid
    # argument does not leave an empty log directory behind
    if ckpt_path != "" and not ckpt_path.endswith(".ckpt"):
        raise ValueError(
            "checkpoint path should end with .ckpt, got {}".format(ckpt_path))

    # init log directory
    log_dir = os.path.join(
        "logs",
        datetime.now().strftime("%Y%m%d-%H%M%S") if log_dir == "" else log_dir)
    if os.path.exists(log_dir):
        logging.warning("Log directory {} exists already.".format(log_dir))
    else:
        # exist_ok: do not crash if the directory appears between the check
        # above and this call (e.g. two runs started in the same second)
        os.makedirs(log_dir, exist_ok=True)

    # load and backup config (a copy is written to the log directory)
    config = config_parser.load_configs(config_path)
    config_parser.save(config=config, out_dir=log_dir)
    return config, log_dir
コード例 #8
0
 def test_outdated_config(self):
     """
     Loading an outdated config must give the same dict as the current demo
     config and must write an "updated_" copy next to the outdated file.

     The written copy is always removed, even when an assertion fails, so a
     stale file from an earlier run cannot make the isfile check pass
     without load_configs having written it.
     """
     updated_file_path = "config/test/updated_grouped_mr_heart_v011.yaml"
     # drop a leftover from an interrupted earlier run
     if os.path.isfile(updated_file_path):
         os.remove(updated_file_path)

     with open("demos/grouped_mr_heart/grouped_mr_heart.yaml") as file:
         expected = yaml.load(file, Loader=yaml.FullLoader)
     try:
         got = load_configs("config/test/grouped_mr_heart_v011.yaml")
         assert got == expected
         assert os.path.isfile(updated_file_path)
     finally:
         # clean up regardless of the assertion outcome
         if os.path.isfile(updated_file_path):
             os.remove(updated_file_path)
コード例 #9
0
 def test_multiple_configs(self):
     """Loading the three partial test configs as a list must give the same
     dict as parsing config/unpaired_labeled_ddf.yaml directly."""
     partial_paths = [
         "config/test/ddf.yaml",
         "config/test/unpaired_nifti.yaml",
         "config/test/labeled.yaml",
     ]
     with open("config/unpaired_labeled_ddf.yaml") as reference_file:
         reference = yaml.load(reference_file, Loader=yaml.FullLoader)
     combined = load_configs(config_path=partial_paths)
     assert combined == reference
コード例 #10
0
def test_load_configs():
    """
    test load_configs by checking outputs

    Both a single config path (given as a str) and a list of partial config
    paths must produce the same dict as parsing
    deepreg/config/unpaired_labeled_ddf.yaml directly.
    """
    reference_path = "deepreg/config/unpaired_labeled_ddf.yaml"
    # the expected dict is the same for both cases, so parse it only once
    with open(reference_path) as file:
        expected = yaml.load(file, Loader=yaml.FullLoader)

    # single config
    # input is str not list
    got = load_configs(reference_path)
    assert got == expected

    # multiple configs
    got = load_configs(config_path=[
        "deepreg/config/test/ddf.yaml",
        "deepreg/config/test/unpaired_nifti.yaml",
        "deepreg/config/test/labeled.yaml",
    ])
    assert got == expected
コード例 #11
0
ファイル: train.py プロジェクト: snehashis1997/DeepReg
def build_config(
    config_path: (str, list),
    log_root: str,
    log_dir: str,
    ckpt_path: str,
    max_epochs: int = -1,
) -> [dict, str, str]:
    """
    Initialise the log directory and parse the configuration for training.

    :param config_path: str or list of str, path(s) to config file(s)
    :param log_root: root of logs
    :param log_dir: path to where training logs to be stored.
    :param ckpt_path: path where model is stored; "~" is expanded to the
        user home directory.
    :param max_epochs: if max_epochs > 0, use it to overwrite the configured
        number of epochs; save_period is then capped at max_epochs.
    :return: - config: a dictionary saving configuration; its
               train/preprocess/batch_size is the value from the file
               multiplied by the number of visible GPUs (at least 1)
             - log_dir: the path of directory to save logs
             - ckpt_path: the checkpoint path with "~" expanded
    """

    # init log directory
    log_dir = build_log_dir(log_root=log_root, log_dir=log_dir)

    # load config
    config = config_parser.load_configs(config_path)

    # replace the ~ with user home path
    ckpt_path = os.path.expanduser(ckpt_path)

    # overwrite epochs and save_period if necessary
    if max_epochs > 0:
        config["train"]["epochs"] = max_epochs
        config["train"]["save_period"] = min(max_epochs,
                                             config["train"]["save_period"])

    # backup config before the GPU scaling below, so the saved copy keeps
    # the per-GPU batch size
    config_parser.save(config=config, out_dir=log_dir)

    # batch_size in original config corresponds to batch_size per GPU
    gpus = tf.config.experimental.list_physical_devices("GPU")
    config["train"]["preprocess"]["batch_size"] *= max(len(gpus), 1)

    return config, log_dir, ckpt_path
コード例 #12
0
def build_config(
    config_path: Union[str, List[str]],
    log_dir: str,
    exp_name: str,
    ckpt_path: str,
    max_epochs: int = -1,
) -> Tuple[Dict, str, str]:
    """
    Initialise the experiment log directory and parse the configuration for
    training.

    :param config_path: str or list of str, path(s) to config file(s)
    :param log_dir: path of the log directory
    :param exp_name: name of the experiment
    :param ckpt_path: path where model is stored; "~" is expanded to the
        user home directory.
    :param max_epochs: if max_epochs > 0, use it to overwrite the configured
        number of epochs; save_period is then capped at max_epochs.
    :return: - config: a dictionary saving configuration
             - log_dir: the path of directory to save logs
             - ckpt_path: the checkpoint path with "~" expanded
    """

    # init log directory
    log_dir = build_log_dir(log_dir=log_dir, exp_name=exp_name)

    # load config
    config = config_parser.load_configs(config_path)

    # replace the ~ with user home path
    ckpt_path = os.path.expanduser(ckpt_path)

    # overwrite epochs and save_period if necessary
    if max_epochs > 0:
        config["train"]["epochs"] = max_epochs
        config["train"]["save_period"] = min(max_epochs,
                                             config["train"]["save_period"])

    # backup config (a copy is written to the log directory)
    config_parser.save(config=config, out_dir=log_dir)

    return config, log_dir, ckpt_path
コード例 #13
0
ファイル: gen_tfrecord.py プロジェクト: acasamitjana/DeepReg
def main(args=None):
    """
    Entry point for gen_tfrecord.

    Parses the command line, loads the config and writes TFRecord files for
    the "train", "valid" and "test" modes into sub-directories of the
    configured dataset "tfrecord_dir". If that directory already exists, the
    user is asked whether to remove it first.

    :param args: list of command-line arguments; None lets argparse read
        sys.argv.
    """

    parser = argparse.ArgumentParser(description="gen_tfrecord")

    # command-line options (both are flags, not positional arguments)
    parser.add_argument("--config_path",
                        "-c",
                        help="Path of config",
                        type=str,
                        required=True)

    parser.add_argument(
        "--examples_per_tfrecord",
        "-n",
        help="Number of examples per tfrecord",
        type=int,
        default=64,
    )

    parsed_args = parser.parse_args(args)

    config = config_parser.load_configs(parsed_args.config_path)
    dataset_config = config["dataset"]
    tfrecord_dir = dataset_config["tfrecord_dir"]
    # NOTE(review): clearing tfrecord_dir presumably makes get_data_loader
    # read the raw data instead of existing TFRecords — confirm.
    dataset_config["tfrecord_dir"] = ""

    # os.path.isdir already implies the path exists
    if os.path.isdir(tfrecord_dir):
        answer = input("%s exists. Remove it or not? Y/N\n" % tfrecord_dir)
        # tolerate surrounding whitespace such as "y " from the terminal
        if answer.strip().lower() == "y":
            shutil.rmtree(tfrecord_dir)
    for mode in ["train", "valid", "test"]:
        data_loader = load.get_data_loader(dataset_config, mode)
        write_tfrecords(
            data_dir=os.path.join(tfrecord_dir, mode),
            data_generator=data_loader.data_generator(),
            examples_per_tfrecord=parsed_args.examples_per_tfrecord,
        )
コード例 #14
0
 def test_single_config(self):
     """A single config path given as a str must load to the same dict as
     parsing that YAML file directly."""
     config_file = "config/unpaired_labeled_ddf.yaml"
     with open(config_file) as stream:
         reference = yaml.load(stream, Loader=yaml.FullLoader)
     loaded = load_configs(config_file)
     assert loaded == reference