Example no. 1
0
def test_split_by_institution_invalid(splits: List[float]) -> None:
    """Splitting by institution with an invalid set of proportions must raise a ValueError."""
    dataset = pd.read_csv(full_ml_test_data_path(DATASET_CSV_FILE_NAME))
    with pytest.raises(ValueError):
        DatasetSplits.from_institutions(
            dataset, splits[0], splits[1], splits[2], shuffle=False)
Example no. 2
0
 def get_model_train_test_dataset_splits(self, dataset_df: pd.DataFrame) -> DatasetSplits:
     """Split the given dataset 60/20/20 (train/test/val) by institution, with shuffling."""
     return DatasetSplits.from_institutions(
         df=dataset_df,
         proportion_train=0.6,
         proportion_val=0.2,
         proportion_test=0.2,
         shuffle=True)
def test_split_by_institution_exclude() -> None:
    """
    Splitting data by institution must honour the flags that exclude institutions,
    or force institutions/subjects into the test set only.
    """
    # 40 subjects spread evenly across 4 institutions
    num_subjects = 40
    df = DataFrame({
        CSV_SUBJECT_HEADER: list(range(num_subjects)),
        CSV_INSTITUTION_HEADER: ["a", "b", "c", "d"] * 10,
        "other": list(range(0, num_subjects))
    })
    all_inst = set(df[CSV_INSTITUTION_HEADER].unique())

    def check_inst_present(splits: DatasetSplits, expected: Set[str],
                           expected_test_set: Optional[Set[str]] = None) -> None:
        # Train and val must contain exactly `expected`; the test set may differ.
        assert set(splits.train[CSV_INSTITUTION_HEADER].unique()) == expected
        assert set(splits.val[CSV_INSTITUTION_HEADER].unique()) == expected
        assert set(splits.test[CSV_INSTITUTION_HEADER].unique()) == (expected_test_set or expected)

    # Without any flags, all 4 institutions appear in each of train, val and test
    splits = DatasetSplits.from_institutions(df, 0.5, 0.2, 0.3)
    check_inst_present(splits, all_inst)

    # Excluding institution "a" removes it from every set
    split1 = DatasetSplits.from_institutions(df, 0.5, 0.2, 0.3, exclude_institutions=["a"])
    check_inst_present(split1, {"b", "c", "d"})

    # Excluding an unknown institution must raise
    with pytest.raises(ValueError) as ex:
        DatasetSplits.from_institutions(df, 0.5, 0.2, 0.3, exclude_institutions=["not present"])
    assert "not present" in str(ex)

    # Institution "a" restricted to the test set only
    split2 = DatasetSplits.from_institutions(df, 0.5, 0.2, 0.3, institutions_for_test_only=["a"])
    check_inst_present(split2, {"b", "c", "d"}, all_inst)

    # An unknown test-only institution must raise
    with pytest.raises(ValueError) as ex:
        DatasetSplits.from_institutions(df, 0.5, 0.2, 0.3, institutions_for_test_only=["not present"])
    assert "not present" in str(ex)

    # Individual subjects can be forced into the test set
    forced_test_subjects = list(df.subject.unique())[:20]
    split3 = DatasetSplits.from_institutions(df, 0.5, 0.2, 0.3, subject_ids_for_test_only=forced_test_subjects)
    assert set(split3.test.subject.unique()).issuperset(forced_test_subjects)

    # An unknown forced test subject must raise
    with pytest.raises(ValueError) as ex:
        DatasetSplits.from_institutions(df, 0.5, 0.2, 0.3, subject_ids_for_test_only=['999'])
    assert "not present" in str(ex)
Example no. 4
0
def test_split_by_institution() -> None:
    """
    Test if splitting by institution is as expected.

    Checks the per-institution split sizes, that train/val/test have no
    overlapping subjects, and that no row of the input dataset is lost.
    """
    random.seed(0)
    splits = [0.5, 0.4, 0.1]
    # One row per institution; columns are the expected [train, val, test] subject counts.
    expected_split_sizes_per_institution = [[5, 3, 2], [45, 36, 9]]
    test_data = {
        CSV_SUBJECT_HEADER: list(range(0, 100)),
        CSV_INSTITUTION_HEADER: ([0] * 10) + ([1] * 90),
        "other": list(range(0, 100))
    }

    test_df = DataFrame(test_data, columns=list(test_data.keys()))
    dataset_splits = DatasetSplits.from_institutions(
        df=test_df,
        proportion_train=splits[0],
        proportion_val=splits[1],
        proportion_test=splits[2],
        shuffle=True)

    train_val_test = [
        dataset_splits.train, dataset_splits.val, dataset_splits.test
    ]

    def get_number_rows_for_institution(split_df: pd.DataFrame, inst: int) -> int:
        """Number of unique subjects of institution `inst` inside the given split."""
        # Mask on the split itself (the original masked via test_df, which is only
        # equivalent because each split is a row-subset of test_df with its index kept).
        return len(split_df.loc[split_df.institutionId == inst].subject.unique())

    # Check institution ratios are as expected.
    # BUG FIX: the original called np.isclose() without asserting its result,
    # so this check could never fail. The counts are integers, so exact
    # equality is the correct assertion.
    for i, inst_id in enumerate(test_df.institutionId.unique()):
        for j, split_df in enumerate(train_val_test):
            assert get_number_rows_for_institution(split_df, inst_id) == \
                   expected_split_sizes_per_institution[i][j]

    # Check that there are no overlaps between the datasets
    assert not set.intersection(*[set(x.subject) for x in train_val_test])

    # Check that all of the data is persisted: concatenating the splits and
    # sorting by subject must reproduce the original dataframe.
    datasets_df = pd.concat(train_val_test)
    pd.testing.assert_frame_equal(
        datasets_df.sort_values([CSV_SUBJECT_HEADER], ascending=True), test_df)
    def get_model_train_test_dataset_splits(self, dataset_df: pd.DataFrame) -> DatasetSplits:
        """
        Split the dataset by institution (train/test/val = 0.8/0.1/0.1, shuffled),
        then move all listed IOV subjects into the test set.

        Institutions with only a handful of images are excluded entirely; two
        institutions (IOV, Oncoclinicas) are reserved for the test set. Any IOV
        subject that still falls into train or val is swapped with a randomly
        chosen non-IOV subject of the same institution from the test set, so the
        per-institution subject counts of each split are preserved.

        :param dataset_df: The full dataset as read from the dataset CSV file.
        :return: The post-processed DatasetSplits.
        :raises ValueError: If after swapping an IOV subject remains in train or
            val, if an IOV subject is missing from the test set, or if the swap
            altered the per-institution subject counts of any split.
        """
        splits = DatasetSplits.from_institutions(
            df=dataset_df,
            proportion_train=0.8,
            proportion_test=0.1,
            proportion_val=0.1,
            shuffle=True,
            exclude_institutions={
                "ac54f75d-c5fa-4e32-a140-485527f8e3a2",  # Birmingham: 1 image
                "af8d9205-2ae1-422f-8b35-67ee435253e1",  # OSL: 2 images
                "87630c93-07d6-49de-844a-3cc99fe9c323",  # Brussels: 3 images
                "5a6ba8fe-65bc-43ec-b1fc-682c8c37e40c",  # VFN: 4 images
            },
            # These institutions have around 40 images each. The main argument in the paper will be about
            # keeping two of those aside as untouched test sets.
            # Oncoclinicas uses Siemens scanner, IOV uses a GE scanner. Most of the other images are from Toshiba
            # scanners.
            institutions_for_test_only={
                # "d527557d-3b9a-45d0-ad57-692e5a199896",  # AZ Groenige
                "85aaee5f-f5f3-4eae-b6cd-26b0070156d8",  # IOV
                "641eda02-90c3-45ed-b8b1-2651b6a5da6c",  # Oncoclinicas
                # "8522ccd1-ab59-4342-a2ce-7d8ad826ab4f",  # UW
            }
        )

        # Series IDs of the IOV subjects that must all end up in the test set.
        iov_subjects = {
            "1ec8a7d58cadb231a0412b674731ee72da0e04ab67f2a2f009a768189bbcf691",
            "439bc48993c6e146c4ab573eeba35990ee843b7495dd0924dc6bd0b331e869db",
            "e5d338a12dfcc519787456b09072a07c6191b7140e036c52bc4d039ef3b28afd",
            "af7ad87cc408934cb2a65029661cb426539429a8aada6e1644a67a056c94f691",
            "227e859ee0bd0c4ff860dd77a20f39fe5924348ff4a4fac15dc94cea2cd07c39",
            "512b22856b7dbde60b4a42c348c4bee5b9efb67024fb708addcddfe1f4841288",
            "906f77caba56df060f5d519ae9b6572a90ac22a04560b4d561f3668e6331e3c3",
            "49a01ffe812b0f3e3d93334866662afb5fb33ba6dcd3cc642d4577a449000649",
            "ab3ed87d55da37a2a665b059b5fef54a0553656e8df51592b8c40f16facd60b9",
            "6eb8aeb8f822e15970d3feb64a618a9ad3de936046d84cb83d2569fbb6c70fcb"}

        def _swap_iov(train_val_df: pd.DataFrame, test_df: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame]:
            """
            Swap the images that are in the IOV and in the Train/Val, with those from the Test set
            of the same institution (to maintain the institution wise distribution of images).
            Returns the updated (train_val_df, test_df) pair.
            """
            random.seed(0)  # make the random.sample calls below deterministic
            # Filter IOV subjects that are not in the test set (as we do not want to swap them).
            # PERF: build the set of test series IDs once, instead of calling unique()
            # for every IOV subject as the original did.
            test_series_ids = set(test_df.seriesId.unique())
            iov_not_in_test = iov_subjects - test_series_ids

            iov_train_val_mask = train_val_df[CSV_SERIES_HEADER].isin(iov_not_in_test)
            iov_train_val_subjects_df = train_val_df.loc[iov_train_val_mask]
            # drop IOV subjects from train/val
            train_val_df = train_val_df.loc[~iov_train_val_mask]
            # Select the same number of subjects for the same institutions from the
            # test set (ignoring the IOV subjects that are already in the test set)
            # and add them to the provided df. ("tet set" typo fixed.)
            for inst in iov_train_val_subjects_df.institutionId.unique():
                same_institution = test_df[CSV_INSTITUTION_HEADER] == inst
                not_iov = ~test_df[CSV_SERIES_HEADER].isin(iov_subjects)
                test_subs = list(test_df.loc[same_institution & not_iov].subject.unique())
                num_train_val_df_subs_to_swap = len(
                    iov_train_val_subjects_df.loc[
                        iov_train_val_subjects_df[CSV_INSTITUTION_HEADER] == inst].subject.unique())
                subjects_to_swap = random.sample(test_subs, k=num_train_val_df_subs_to_swap)
                # rows of the test df to swap
                to_swap = test_df[CSV_SUBJECT_HEADER].isin(subjects_to_swap)
                # swap: move the chosen test rows into train/val, drop them from test
                train_val_df = pd.concat([train_val_df, test_df.loc[to_swap]])
                test_df = test_df.loc[~to_swap]

            return train_val_df, test_df

        train_swap, test_swap = _swap_iov(splits.train, splits.test)
        val_swap, test_swap = _swap_iov(splits.val, test_swap)
        # Make sure every IOV subject's rows are present in the test set.
        test_swap = pd.concat(
            [test_swap, dataset_df.loc[dataset_df[CSV_SERIES_HEADER].isin(iov_subjects)]]).drop_duplicates()

        swapped_splits = DatasetSplits(
            train=train_swap,
            val=val_swap,
            test=test_swap
        )

        # Sanity checks: no IOV subject may remain in train or val ...
        iov_intersection = set(swapped_splits.train.seriesId.unique()).intersection(iov_subjects)
        if len(iov_intersection) != 0:
            raise ValueError(f"Train split has IOV subjects {iov_intersection}")
        iov_intersection = set(swapped_splits.val.seriesId.unique()).intersection(iov_subjects)
        if len(iov_intersection) != 0:
            raise ValueError(f"Val split has IOV subjects {iov_intersection}")

        # ... and every IOV subject must be present in the test set.
        iov_missing = iov_subjects.difference(swapped_splits.test.seriesId.unique())
        if len(iov_missing) != 0:
            # BUG FIX: the message previously rendered as "found f{...}" because of a
            # stray "f" typed inside the f-string.
            raise ValueError(f"All IOV subjects must be in the Test split, found {iov_missing} that are not")

        def _check_df_distribution(_old_df: pd.DataFrame, _new_df: pd.DataFrame) -> None:
            """Verify the per-institution unique-subject counts are unchanged by the swap."""
            _old_df_inst = _old_df.drop_duplicates(CSV_SUBJECT_HEADER).groupby([CSV_INSTITUTION_HEADER]).groups
            _new_df_inst = _new_df.drop_duplicates(CSV_SUBJECT_HEADER).groupby([CSV_INSTITUTION_HEADER]).groups
            for k, v in _old_df_inst.items():
                # NOTE(review): raises KeyError (not ValueError) if an institution
                # vanished from _new_df entirely - presumably cannot happen after a swap.
                if len(v) != len(_new_df_inst[k]):
                    raise ValueError(f"Expected _new_df to be length={len(v)} found {_new_df_inst[k]}")

        _check_df_distribution(splits.train, swapped_splits.train)
        _check_df_distribution(splits.val, swapped_splits.val)
        _check_df_distribution(splits.test, swapped_splits.test)

        return swapped_splits