예제 #1
0
def test_get_data_ids():
    """
    Verify that NiftiFileLoader.get_data_ids lists the expected files.

    A non-grouped loader over the paired test "fixed_images" folder should
    report exactly the two test cases, as "/<file name>" strings.
    """
    loader = NiftiFileLoader(
        dir_path="./data/test/nifti/paired/test",
        name="fixed_images",
        grouped=False,
    )
    data_ids = loader.get_data_ids()
    # Release the loader before asserting so a failed comparison cannot leak it.
    loader.close()
    assert data_ids == ["/case000025.nii.gz", "/case000026.nii.gz"]
예제 #2
0
def _assert_data_ids(dir_paths, name, grouped, expected):
    """
    Build a NiftiFileLoader and check that get_data_ids returns `expected`.

    The loader is closed in a ``finally`` clause so it is released even when
    ``get_data_ids`` raises; the comparison happens after closing.
    """
    loader = NiftiFileLoader(dir_paths=dir_paths, name=name, grouped=grouped)
    try:
        got = loader.get_data_ids()
    finally:
        loader.close()
    assert got == expected


def test_get_data_ids():
    """
    check if the get_data_ids method works as expected

    Covers paired, unpaired, grouped and multi-directory layouts, then checks
    that get_data rejects malformed indices. Every loader is closed, including
    the ones used only for the invalid-index checks.
    """
    # paired
    _assert_data_ids(
        dir_paths=["./data/test/nifti/paired/test"],
        name="fixed_images",
        grouped=False,
        expected=[
            ("./data/test/nifti/paired/test", "case000025"),
            ("./data/test/nifti/paired/test", "case000026"),
        ],
    )

    # unpaired
    _assert_data_ids(
        dir_paths=["./data/test/nifti/unpaired/test"],
        name="images",
        grouped=False,
        expected=[
            ("./data/test/nifti/unpaired/test", "case000025"),
            ("./data/test/nifti/unpaired/test", "case000026"),
        ],
    )

    # grouped
    _assert_data_ids(
        dir_paths=["./data/test/nifti/grouped/test"],
        name="images",
        grouped=True,
        expected=[
            ("./data/test/nifti/grouped/test", "group1", "case000025"),
            ("./data/test/nifti/grouped/test", "group1", "case000026"),
        ],
    )

    # multi dirs
    # NOTE: train is passed before test, yet the expected ids list the test
    # dir first -- presumably the loader sorts its ids; confirm in the loader.
    _assert_data_ids(
        dir_paths=[
            "./data/test/nifti/grouped/train", "./data/test/nifti/grouped/test"
        ],
        name="images",
        grouped=True,
        expected=[
            ("./data/test/nifti/grouped/test", "group1", "case000025"),
            ("./data/test/nifti/grouped/test", "group1", "case000026"),
            ("./data/test/nifti/grouped/train", "group1", "case000000"),
            ("./data/test/nifti/grouped/train", "group1", "case000001"),
            ("./data/test/nifti/grouped/train", "group1", "case000003"),
            ("./data/test/nifti/grouped/train", "group1", "case000008"),
            ("./data/test/nifti/grouped/train", "group2", "case000009"),
            ("./data/test/nifti/grouped/train", "group2", "case000011"),
            ("./data/test/nifti/grouped/train", "group2", "case000012"),
        ],
    )

    # wrong index for paired
    loader = NiftiFileLoader(
        dir_paths=["./data/test/nifti/paired/test"],
        name="fixed_images",
        grouped=False,
    )
    try:
        with pytest.raises(AssertionError):
            # tuple index on a non-grouped loader
            loader.get_data(index=(0, 1))
        with pytest.raises(ValueError) as err_info:
            # list is neither int nor tuple
            loader.get_data(index=[0])
        assert "must be int, or tuple" in str(err_info.value)
    finally:
        # previously this loader was never closed
        loader.close()

    # wrong index for grouped
    loader = NiftiFileLoader(
        dir_paths=["./data/test/nifti/grouped/test"],
        name="images",
        grouped=True,
    )
    try:
        with pytest.raises(AssertionError):
            # negative group_index
            loader.get_data(index=(-1, 1))
        with pytest.raises(IndexError):
            # out of range group_index
            loader.get_data(index=(32, 1))
        with pytest.raises(AssertionError):
            # negative in_group_data_index
            loader.get_data(index=(0, -1))
        with pytest.raises(IndexError):
            # out of range in_group_data_index
            loader.get_data(index=(0, 32))
    finally:
        # previously this loader was never closed
        loader.close()