def test_data_loader(self, data_type: str, format: str):
    """
    Check that a data loader can be built from a test configuration.

    The configuration is read from
    ``config/test/{data_type}_{format}.yaml`` and the loader is built
    for the "train" mode from its "dataset" section.

    :param data_type: name of data loader for registry
    :param format: name of file loader for registry
    """
    # the config file name is derived from the two registry names
    config_path = f"config/test/{data_type}_{format}.yaml"
    dataset_config = load_yaml(config_path)["dataset"]

    loader = load.get_data_loader(data_config=dataset_config, mode="train")

    # the built loader must be an instance of the class
    # registered under data_type
    loader_cls = REGISTRY.get(category=DATA_LOADER_CLASS, key=data_type)
    assert isinstance(loader, loader_cls)  # type: ignore
def get_data_loader(data_config: dict, split: str) -> Optional[DataLoader]:
    """
    Build the data loader of one data split from the data configuration.

    Can't be placed in the same file of loader interfaces
    as it causes import cycle.

    :param data_config: a dictionary containing configuration for data
    :param split: name of the split, must be in KNOWN_DATA_SPLITS
    :return: DataLoader, or None if the split is not in the config
        or its "dir" entry is missing, None or an empty string.
    :raises ValueError: if split is not a known split, or if one of
        the configured paths is not an existing directory.
    """
    if split not in KNOWN_DATA_SPLITS:
        raise ValueError(
            f"split must be one of {KNOWN_DATA_SPLITS}, got {split}")

    # nothing to load if the split is not configured or has no directory
    if split not in data_config:
        return None
    split_config = data_config[split]
    dirs = split_config.get("dir", None)
    if dirs is None or dirs == "":
        return None

    # a single path is accepted as well as a list of paths
    if isinstance(dirs, str):
        dirs = [dirs]
    # replace ~ with user home path
    dirs = [os.path.expanduser(path) for path in dirs]
    for path in dirs:
        if os.path.isdir(path):
            continue
        raise ValueError(
            f"Data directory path {path} for split {split}"
            f" is not a directory or does not exist")

    # loader config: a copy of the data config without the per-split
    # sections, with the "type" entry renamed to "name"
    config_copy = deepcopy(data_config)
    loader_config = {}
    for key, value in config_copy.items():
        if key not in KNOWN_DATA_SPLITS:
            loader_config[key] = value
    loader_config["name"] = loader_config.pop("type")

    file_loader_cls = REGISTRY.get(
        category=FILE_LOADER_CLASS, key=split_config["format"])
    is_train = split == "train"
    default_args = {
        "data_dir_paths": dirs,
        "file_loader": file_loader_cls,
        "labeled": split_config["labeled"],
        # non-train splits use sample_label "all" and a fixed seed of 0
        "sample_label": "sample" if is_train else "all",
        "seed": None if is_train else 0,
    }
    return REGISTRY.build_data_loader(
        config=loader_config, default_args=default_args)
def get_data_loader(data_config: dict, mode: str) -> Optional[DataLoader]:
    """
    Build the data loader of one mode from the data configuration.

    Can't be placed in the same file of loader interfaces
    as it causes import cycle.

    :param data_config: a dictionary containing configuration for data
    :param mode: string, must be train/valid/test
    :return: DataLoader, or None if no directory is configured for
        the mode (entry missing, None or an empty string).
    :raises AssertionError: if mode is not one of train/valid/test.
    :raises ValueError: if one of the configured paths
        is not an existing directory.
    """
    assert mode in ("train", "valid", "test"), "mode must be one of train/valid/test"

    # nothing to load if the mode has no directory configured
    mode_dirs = data_config["dir"].get(mode, None)
    if mode_dirs is None or mode_dirs == "":
        return None

    # a single path is accepted as well as a list of paths
    dir_paths = [mode_dirs] if isinstance(mode_dirs, str) else mode_dirs
    # replace ~ with user home path
    dir_paths = [os.path.expanduser(path) for path in dir_paths]
    for path in dir_paths:
        if os.path.isdir(path):
            continue
        raise ValueError(
            f"Data directory path {path} for mode {mode}"
            f" is not a directory or does not exist")

    # loader config: a copy of the data config without "dir" and
    # "format", with the "type" entry renamed to "name"
    loader_config = deepcopy(data_config)
    for key in ("dir", "format"):
        loader_config.pop(key)
    loader_config["name"] = loader_config.pop("type")

    file_loader_cls = REGISTRY.get(
        category=FILE_LOADER_CLASS, key=data_config["format"])
    is_train = mode == "train"
    default_args = {
        "data_dir_paths": dir_paths,
        "file_loader": file_loader_cls,
        "labeled": data_config["labeled"],
        # non-train modes use sample_label "all" and a fixed seed of 0
        "sample_label": "sample" if is_train else "all",
        "seed": None if is_train else 0,
    }
    return REGISTRY.build_data_loader(
        config=loader_config, default_args=default_args)
def test_get_backbone(self):
    """Check that "unet" can be fetched from the "backbone_class" category."""
    # the lookup itself is the check: the test passes
    # only if the call does not raise
    REGISTRY.get("backbone_class", "unet")
def test_get_err(self):
    """Check that fetching an unknown key raises a ValueError."""
    with pytest.raises(ValueError) as err_info:
        REGISTRY.get("backbone_class", "wrong_key")

    # the error message must mention the missing registration
    message = str(err_info.value)
    assert "has not been registered" in message
def test_register(self):
    """Check that a registered value is stored and can be fetched back."""
    category = "backbone_class"
    key = "test_key"
    value = 0

    REGISTRY.register(category=category, name=key, cls=value)

    # the value is stored under the (category, key) pair ...
    stored = REGISTRY._dict[(category, key)]
    assert stored == value
    # ... and the public getter returns the same value
    fetched = REGISTRY.get(category, key)
    assert fetched == value