예제 #1
0
    def test_data_loader(self, data_type: str, format: str):
        """
        Check that a training data loader can be built from a test YAML config.

        :param data_type: registry key of the expected data loader class
        :param format: file format name; here it only selects which test
            config file ``config/test/{data_type}_{format}.yaml`` is loaded
        """
        config_path = f"config/test/{data_type}_{format}.yaml"
        dataset_config = load_yaml(config_path)["dataset"]
        loader = load.get_data_loader(data_config=dataset_config, mode="train")
        # the built loader must be an instance of the class registered under data_type
        loader_cls = REGISTRY.get(category=DATA_LOADER_CLASS, key=data_type)
        assert isinstance(loader, loader_cls)  # type: ignore
예제 #2
0
def get_data_loader(data_config: dict, split: str) -> Optional[DataLoader]:
    """
    Return the corresponding data loader.

    Can't be placed in the same file of loader interfaces as it causes import cycle.

    :param data_config: a dictionary containing configuration for data
    :param split: must be train/valid/test
    :return: DataLoader or None, returns None if the split or dir is empty.
    :raises ValueError: if split is unknown, or if a configured data
        directory is not an existing directory.
    """
    if split not in KNOWN_DATA_SPLITS:
        raise ValueError(
            f"split must be one of {KNOWN_DATA_SPLITS}, got {split}")

    # a split section that is missing, None or empty means no data for it
    split_config = data_config.get(split)
    if not split_config:
        return None

    # None, "" and [] all mean that no data directory was configured
    data_dir_paths = split_config.get("dir", None)
    if not data_dir_paths:
        return None

    if isinstance(data_dir_paths, str):
        data_dir_paths = [data_dir_paths]
    # replace ~ with user home path
    data_dir_paths = [os.path.expanduser(path) for path in data_dir_paths]
    for data_dir_path in data_dir_paths:
        if not os.path.isdir(data_dir_path):
            raise ValueError(
                f"Data directory path {data_dir_path} for split {split}"
                f" is not a directory or does not exist")

    # prepare data loader config: drop all per-split sections first so only
    # the shared entries are deep-copied (the caller's dict is never mutated)
    data_loader_config = deepcopy({
        k: v
        for k, v in data_config.items() if k not in KNOWN_DATA_SPLITS
    })
    data_loader_config["name"] = data_loader_config.pop("type")

    default_args = dict(
        data_dir_paths=data_dir_paths,
        file_loader=REGISTRY.get(category=FILE_LOADER_CLASS,
                                 key=split_config["format"]),
        labeled=split_config["labeled"],
        # only the train split samples labels randomly; other splits are
        # evaluated on all labels with a fixed seed
        sample_label="sample" if split == "train" else "all",
        seed=None if split == "train" else 0,
    )
    data_loader: DataLoader = REGISTRY.build_data_loader(
        config=data_loader_config, default_args=default_args)
    return data_loader
예제 #3
0
def get_data_loader(data_config: dict, mode: str) -> Optional[DataLoader]:
    """
    Return the corresponding data loader.

    Can't be placed in the same file of loader interfaces as it causes import cycle.

    :param data_config: a dictionary containing configuration for data
    :param mode: string, must be train/valid/test
    :return: DataLoader or None, returns None if the data_dir_paths is empty
    :raises AssertionError: if mode is not one of train/valid/test
    :raises ValueError: if a configured data directory is not an existing
        directory
    """
    # explicit check rather than ``assert`` so validation is not stripped
    # under ``python -O``; AssertionError is kept for existing callers
    if mode not in ("train", "valid", "test"):
        raise AssertionError("mode must be one of train/valid/test")

    # None, "" and [] all mean that no data directory was configured
    data_dir_paths = data_config["dir"].get(mode, None)
    if not data_dir_paths:
        return None
    if isinstance(data_dir_paths, str):
        data_dir_paths = [data_dir_paths]
    # replace ~ with user home path
    data_dir_paths = [os.path.expanduser(path) for path in data_dir_paths]
    for data_dir_path in data_dir_paths:
        if not os.path.isdir(data_dir_path):
            raise ValueError(
                f"Data directory path {data_dir_path} for mode {mode}"
                f" is not a directory or does not exist")

    # prepare data loader config on a copy so the caller's dict is not mutated
    data_loader_config = deepcopy(data_config)
    data_loader_config.pop("dir")
    data_loader_config.pop("format")
    data_loader_config["name"] = data_loader_config.pop("type")

    default_args = dict(
        data_dir_paths=data_dir_paths,
        file_loader=REGISTRY.get(category=FILE_LOADER_CLASS,
                                 key=data_config["format"]),
        labeled=data_config["labeled"],
        # only train mode samples labels randomly; other modes use all
        # labels with a fixed seed
        sample_label="sample" if mode == "train" else "all",
        seed=None if mode == "train" else 0,
    )
    data_loader = REGISTRY.build_data_loader(config=data_loader_config,
                                             default_args=default_args)
    return data_loader
예제 #4
0
 def test_get_backbone(self):
     """The "unet" backbone is registered: REGISTRY.get must not raise."""
     REGISTRY.get("backbone_class", "unet")
예제 #5
0
 def test_get_err(self):
     """Looking up an unregistered key raises ValueError with a clear message."""
     with pytest.raises(ValueError, match="has not been registered"):
         REGISTRY.get("backbone_class", "wrong_key")
예제 #6
0
 def test_register(self):
     """A registered value is stored internally and returned by REGISTRY.get."""
     category = "backbone_class"
     name = "test_key"
     registered = 0
     REGISTRY.register(category=category, name=name, cls=registered)
     # check the internal storage first, then the public accessor
     assert REGISTRY._dict[(category, name)] == registered
     assert REGISTRY.get(category, name) == registered