def test_build_config():
    """
    Test build_config and check log_dir setting and checkpoint path verification
    """
    yaml_path = "deepreg/config/unpaired_labeled_ddf.yaml"
    dir_name = "test_build_config"
    expected_log_dir = os.path.join("logs", dir_name)

    # both an empty checkpoint path and a .ckpt path are accepted
    for ckpt in ["", "example.ckpt"]:
        cfg, out_dir = build_config(
            config_path=yaml_path, log_dir=dir_name, ckpt_path=ckpt
        )
        assert isinstance(cfg, dict)
        assert out_dir == expected_log_dir

    # a checkpoint path ending with .h5 is rejected
    with pytest.raises(ValueError):
        build_config(config_path=yaml_path, log_dir=dir_name, ckpt_path="example.h5")
def test_ckpt_path_err(self):
    """A checkpoint path ending with .h5 must raise a ValueError."""
    with pytest.raises(ValueError) as exc:
        build_config(
            config_path=self.config_path,
            log_root=log_root,
            log_dir=self.log_dir,
            ckpt_path="example.h5",
        )
    # the error message names the required extension
    message = str(exc.value)
    assert "checkpoint path should end with .ckpt" in message
def test_build_dataset():
    """
    Test build_dataset by checking the output types
    """
    # build the config first
    cfg, _ = build_config(
        config_path="deepreg/config/unpaired_labeled_ddf.yaml",
        log_dir="test_build_dataset",
        ckpt_path="",
    )

    # build the train / validation datasets
    train_out, val_out = build_dataset(
        dataset_config=cfg["dataset"],
        preprocess_config=cfg["train"]["preprocess"],
    )

    # each output is a (data_loader, dataset, steps_per_epoch) triple
    for loader, dataset, steps in (train_out, val_out):
        assert isinstance(loader, DataLoader)
        assert isinstance(dataset, tf.data.Dataset)
        assert isinstance(steps, int)
def test_max_epochs(self, max_epochs, expected_epochs, expected_save_period):
    """max_epochs overrides both epochs and save_period in the train config."""
    cfg, _ = build_config(
        config_path=self.config_path,
        log_root=log_root,
        log_dir=self.log_dir,
        ckpt_path="",
        max_epochs=max_epochs,
    )
    train_cfg = cfg["train"]
    assert train_cfg["epochs"] == expected_epochs
    assert train_cfg["save_period"] == expected_save_period
def test_ckpt_path(self, ckpt_path):
    """A valid checkpoint path passes and the log dir is derived from log_root."""
    cfg, out_dir = build_config(
        config_path=self.config_path,
        log_root=log_root,
        log_dir=self.log_dir,
        ckpt_path=ckpt_path,
    )
    assert isinstance(cfg, dict)
    assert out_dir == os.path.join(log_root, self.log_dir)
def test_ckpt_path(self, ckpt_path):
    """A valid checkpoint path passes and the log dir combines log_dir and exp_name."""
    cfg, out_dir, _ = build_config(
        config_path=self.config_path,
        log_dir=self.log_dir,
        exp_name=self.exp_name,
        ckpt_path=ckpt_path,
    )
    assert isinstance(cfg, dict)
    assert out_dir == os.path.join(self.log_dir, self.exp_name)
def test_build_dataset():
    """
    Test build_dataset by checking the output types
    """
    cfg, _, _ = build_config(
        config_path="config/unpaired_labeled_ddf.yaml",
        log_dir="logs",
        exp_name="test_build_dataset",
        ckpt_path="",
    )

    # the train split yields a loader / dataset / steps of the expected types
    loader, dataset, steps = build_dataset(
        dataset_config=cfg["dataset"],
        preprocess_config=cfg["train"]["preprocess"],
        split="train",
        training=False,
        repeat=False,
    )
    assert isinstance(loader, DataLoader)
    assert isinstance(dataset, tf.data.Dataset)
    assert isinstance(steps, int)

    # with the valid data dir removed, every output is None
    cfg["dataset"]["valid"]["dir"] = ""
    loader, dataset, steps = build_dataset(
        dataset_config=cfg["dataset"],
        preprocess_config=cfg["train"]["preprocess"],
        split="valid",
        training=False,
        repeat=False,
    )
    assert loader is None
    assert dataset is None
    assert steps is None