예제 #1
0
def test_build_config():
    """
    Check build_config for three checkpoint paths.

    An empty checkpoint path and one ending in ``.ckpt`` must both yield a
    dict config and a log directory of ``logs/<log_dir>``. A checkpoint path
    ending in ``.h5`` must raise ValueError.
    """
    yaml_path = "deepreg/config/unpaired_labeled_ddf.yaml"
    run_name = "test_build_config"

    # accepted checkpoint paths, exercised in the same order as before
    for accepted_ckpt in ("", "example.ckpt"):
        cfg, out_dir = build_config(
            config_path=yaml_path, log_dir=run_name, ckpt_path=accepted_ckpt
        )
        assert isinstance(cfg, dict)
        assert out_dir == os.path.join("logs", run_name)

    # a .h5 checkpoint path is rejected
    with pytest.raises(ValueError):
        build_config(
            config_path=yaml_path, log_dir=run_name, ckpt_path="example.h5"
        )
예제 #2
0
 def test_ckpt_path_err(self):
     """Check that a checkpoint path ending in .h5 raises ValueError with a .ckpt hint."""
     bad_ckpt = "example.h5"
     with pytest.raises(ValueError) as raised:
         build_config(
             config_path=self.config_path,
             log_root=log_root,
             log_dir=self.log_dir,
             ckpt_path=bad_ckpt,
         )
     message = str(raised.value)
     assert "checkpoint path should end with .ckpt" in message
예제 #3
0
def test_build_dataset():
    """
    Build the train and val splits from the unpaired labeled DDF config.

    Check that each split is a (DataLoader, tf.data.Dataset, int) triple.
    """
    # the returned log directory is not needed by this test
    config, _ = build_config(
        config_path="deepreg/config/unpaired_labeled_ddf.yaml",
        log_dir="test_build_dataset",
        ckpt_path="",
    )

    train_triple, val_triple = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
    )

    # unpack both splits up front so a malformed result fails before any check
    train_loader, train_dataset, train_steps = train_triple
    val_loader, val_dataset, val_steps = val_triple

    for loader, dataset, steps in (
        (train_loader, train_dataset, train_steps),
        (val_loader, val_dataset, val_steps),
    ):
        assert isinstance(loader, DataLoader)
        assert isinstance(dataset, tf.data.Dataset)
        assert isinstance(steps, int)
예제 #4
0
 def test_max_epochs(self, max_epochs, expected_epochs, expected_save_period):
     """Check the train epochs and save period build_config produces for a given max_epochs."""
     config, _ = build_config(
         config_path=self.config_path,
         log_root=log_root,
         log_dir=self.log_dir,
         ckpt_path="",
         max_epochs=max_epochs,
     )
     train_cfg = config["train"]
     assert train_cfg["epochs"] == expected_epochs
     assert train_cfg["save_period"] == expected_save_period
예제 #5
0
    def test_ckpt_path(self, ckpt_path):
        """Check that build_config runs for ckpt_path and logs under log_root/log_dir."""
        config, log_dir_out = build_config(
            config_path=self.config_path,
            log_root=log_root,
            log_dir=self.log_dir,
            ckpt_path=ckpt_path,
        )
        assert isinstance(config, dict)
        expected_log_dir = os.path.join(log_root, self.log_dir)
        assert log_dir_out == expected_log_dir
예제 #6
0
    def test_ckpt_path(self, ckpt_path):
        """Check that build_config runs for ckpt_path and logs under log_dir/exp_name."""
        # the third returned value is not inspected here
        config, log_dir_out, _ = build_config(
            config_path=self.config_path,
            log_dir=self.log_dir,
            exp_name=self.exp_name,
            ckpt_path=ckpt_path,
        )
        assert isinstance(config, dict)
        expected_log_dir = os.path.join(self.log_dir, self.exp_name)
        assert log_dir_out == expected_log_dir
예제 #7
0
def test_build_dataset():
    """
    Check build_dataset for the train split and for a valid split with no data.

    The train split must be a (DataLoader, tf.data.Dataset, int) triple. Once
    the valid split's "dir" entry is cleared, all three outputs for that split
    must be None.
    """
    config, _, _ = build_config(
        config_path="config/unpaired_labeled_ddf.yaml",
        log_dir="logs",
        exp_name="test_build_dataset",
        ckpt_path="",
    )
    # same dict objects as config["dataset"] / config["train"]["preprocess"]
    dataset_cfg = config["dataset"]
    preprocess_cfg = config["train"]["preprocess"]

    def load(split):
        # evaluation-mode loading: no training augmentation, no repetition
        return build_dataset(
            dataset_config=dataset_cfg,
            preprocess_config=preprocess_cfg,
            split=split,
            training=False,
            repeat=False,
        )

    train_loader, train_dataset, train_steps = load("train")
    assert isinstance(train_loader, DataLoader)
    assert isinstance(train_dataset, tf.data.Dataset)
    assert isinstance(train_steps, int)

    # clear the valid data directory in place (mutates config["dataset"])
    dataset_cfg["valid"]["dir"] = ""

    valid_loader, valid_dataset, valid_steps = load("valid")
    for item in (valid_loader, valid_dataset, valid_steps):
        assert item is None