# NOTE: the parametrization below is illustrative only; any proportions that do not describe a
# valid train/val/test split (negative values, or values that do not sum to 1) should raise.
@pytest.mark.parametrize("splits", [[0.5, 0.6, 0.1],    # sums to more than 1
                                    [0.3, 0.3, 0.3],    # sums to less than 1
                                    [-0.2, 0.6, 0.6]])  # negative proportion
def test_split_by_institution_invalid(splits: List[float]) -> None:
    df1 = pd.read_csv(full_ml_test_data_path(DATASET_CSV_FILE_NAME))
    with pytest.raises(ValueError):
        DatasetSplits.from_institutions(df1, splits[0], splits[1], splits[2], shuffle=False)
def get_model_train_test_dataset_splits(self, dataset_df: pd.DataFrame) -> DatasetSplits:
    return DatasetSplits.from_institutions(df=dataset_df,
                                           proportion_train=0.6,
                                           proportion_test=0.2,
                                           proportion_val=0.2,
                                           shuffle=True)
def test_split_by_institution_exclude() -> None:
    """
    Test if splitting data by institution correctly handles the "exclude institution" flags.
    """
    # 40 subjects across 4 institutions
    test_data = {
        CSV_SUBJECT_HEADER: list(range(40)),
        CSV_INSTITUTION_HEADER: ["a", "b", "c", "d"] * 10,
        "other": list(range(0, 40))
    }
    df = DataFrame(test_data)
    all_inst = set(df[CSV_INSTITUTION_HEADER].unique())

    def check_inst_present(splits: DatasetSplits,
                           expected: Set[str],
                           expected_test_set: Optional[Set[str]] = None) -> None:
        assert expected == set(splits.train[CSV_INSTITUTION_HEADER].unique())
        assert expected == set(splits.val[CSV_INSTITUTION_HEADER].unique())
        assert (expected_test_set or expected) == set(splits.test[CSV_INSTITUTION_HEADER].unique())

    # Normal functionality: all 4 institutions should be present in each of train, val, test
    splits = DatasetSplits.from_institutions(df, 0.5, 0.2, 0.3)
    check_inst_present(splits, all_inst)
    # Exclude institution "a" from all sets
    split1 = DatasetSplits.from_institutions(df, 0.5, 0.2, 0.3, exclude_institutions=["a"])
    check_inst_present(split1, {"b", "c", "d"})
    with pytest.raises(ValueError) as ex:
        DatasetSplits.from_institutions(df, 0.5, 0.2, 0.3, exclude_institutions=["not present"])
    assert "not present" in str(ex)
    # Put "a" only into the test set:
    split2 = DatasetSplits.from_institutions(df, 0.5, 0.2, 0.3, institutions_for_test_only=["a"])
    check_inst_present(split2, {"b", "c", "d"}, all_inst)
    with pytest.raises(ValueError) as ex:
        DatasetSplits.from_institutions(df, 0.5, 0.2, 0.3, institutions_for_test_only=["not present"])
    assert "not present" in str(ex)
    forced_subjects_in_test = list(df.subject.unique())[:20]
    split3 = DatasetSplits.from_institutions(df, 0.5, 0.2, 0.3,
                                             subject_ids_for_test_only=forced_subjects_in_test)
    assert set(split3.test.subject.unique()).issuperset(forced_subjects_in_test)
    with pytest.raises(ValueError) as ex:
        DatasetSplits.from_institutions(df, 0.5, 0.2, 0.3, subject_ids_for_test_only=['999'])
    assert "not present" in str(ex)
def test_split_by_institution() -> None:
    """
    Test if splitting by institution is as expected.
    """
    random.seed(0)
    splits = [0.5, 0.4, 0.1]
    expected_split_sizes_per_institution = [[5, 3, 2], [45, 36, 9]]
    test_data = {
        CSV_SUBJECT_HEADER: list(range(0, 100)),
        CSV_INSTITUTION_HEADER: ([0] * 10) + ([1] * 90),
        "other": list(range(0, 100))
    }
    test_df = DataFrame(test_data, columns=list(test_data.keys()))
    dataset_splits = DatasetSplits.from_institutions(
        df=test_df,
        proportion_train=splits[0],
        proportion_val=splits[1],
        proportion_test=splits[2],
        shuffle=True)
    train_val_test = [dataset_splits.train, dataset_splits.val, dataset_splits.test]

    # Check institution ratios are as expected
    get_number_rows_for_institution = \
        lambda _x, _i: len(_x.loc[test_df.institutionId == _i].subject.unique())
    for i, inst_id in enumerate(test_df.institutionId.unique()):
        # noinspection PyTypeChecker
        for j, df in enumerate(train_val_test):
            assert np.isclose(get_number_rows_for_institution(df, inst_id),
                              expected_split_sizes_per_institution[i][j])
    # Check that there are no overlaps between the datasets
    assert not set.intersection(*[set(x.subject) for x in train_val_test])
    # Check that all of the data is preserved
    datasets_df = pd.concat(train_val_test)
    pd.testing.assert_frame_equal(
        datasets_df.sort_values([CSV_SUBJECT_HEADER], ascending=True), test_df)
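# The tests above treat DatasetSplits.from_institutions as a black box. As a point of reference, the
# sketch below shows one way a per-institution proportional split could be computed with plain pandas.
# It is an illustration only, not the library implementation; the helper name `split_within_institutions`
# and the hard-coded column names "subject" and "institutionId" are assumptions for this example.
import random as _random
from typing import List as _List

import pandas as _pd


def split_within_institutions(df: _pd.DataFrame,
                              proportions: _List[float],
                              seed: int = 0):
    """Split subjects into train/val/test within each institution, roughly preserving the given ratios."""
    assert len(proportions) == 3 and abs(sum(proportions) - 1.0) < 1e-6
    rng = _random.Random(seed)
    parts = [[], [], []]  # collected per-institution frames for train, val, test
    for _, inst_df in df.groupby("institutionId"):
        subjects = list(inst_df["subject"].unique())
        rng.shuffle(subjects)
        n_train = int(round(proportions[0] * len(subjects)))
        n_val = int(round(proportions[1] * len(subjects)))
        buckets = [subjects[:n_train], subjects[n_train:n_train + n_val], subjects[n_train + n_val:]]
        for part, bucket in zip(parts, buckets):
            part.append(inst_df[inst_df["subject"].isin(bucket)])
    return tuple(_pd.concat(p) for p in parts)  # (train, val, test)


# Example: 100 subjects in 2 institutions, split 50/40/10 percent within each institution.
example_df = _pd.DataFrame({"subject": range(100),
                            "institutionId": [0] * 10 + [1] * 90})
example_train, example_val, example_test = split_within_institutions(example_df, [0.5, 0.4, 0.1])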
def get_model_train_test_dataset_splits(self, dataset_df: pd.DataFrame) -> DatasetSplits:
    splits = DatasetSplits.from_institutions(
        df=dataset_df,
        proportion_train=0.8,
        proportion_test=0.1,
        proportion_val=0.1,
        shuffle=True,
        exclude_institutions={
            "ac54f75d-c5fa-4e32-a140-485527f8e3a2",  # Birmingham: 1 image
            "af8d9205-2ae1-422f-8b35-67ee435253e1",  # OSL: 2 images
            "87630c93-07d6-49de-844a-3cc99fe9c323",  # Brussels: 3 images
            "5a6ba8fe-65bc-43ec-b1fc-682c8c37e40c",  # VFN: 4 images
        },
        # These institutions have around 40 images each. The main argument in the paper will be about
        # keeping two of those aside as untouched test sets.
        # Oncoclinicas uses a Siemens scanner, IOV uses a GE scanner. Most of the other images are from
        # Toshiba scanners.
        institutions_for_test_only={
            # "d527557d-3b9a-45d0-ad57-692e5a199896",  # AZ Groenige
            "85aaee5f-f5f3-4eae-b6cd-26b0070156d8",  # IOV
            "641eda02-90c3-45ed-b8b1-2651b6a5da6c",  # Oncoclinicas
            # "8522ccd1-ab59-4342-a2ce-7d8ad826ab4f",  # UW
        }
    )
    # IOV subjects not in the test set already
    iov_subjects = {
        "1ec8a7d58cadb231a0412b674731ee72da0e04ab67f2a2f009a768189bbcf691",
        "439bc48993c6e146c4ab573eeba35990ee843b7495dd0924dc6bd0b331e869db",
        "e5d338a12dfcc519787456b09072a07c6191b7140e036c52bc4d039ef3b28afd",
        "af7ad87cc408934cb2a65029661cb426539429a8aada6e1644a67a056c94f691",
        "227e859ee0bd0c4ff860dd77a20f39fe5924348ff4a4fac15dc94cea2cd07c39",
        "512b22856b7dbde60b4a42c348c4bee5b9efb67024fb708addcddfe1f4841288",
        "906f77caba56df060f5d519ae9b6572a90ac22a04560b4d561f3668e6331e3c3",
        "49a01ffe812b0f3e3d93334866662afb5fb33ba6dcd3cc642d4577a449000649",
        "ab3ed87d55da37a2a665b059b5fef54a0553656e8df51592b8c40f16facd60b9",
        "6eb8aeb8f822e15970d3feb64a618a9ad3de936046d84cb83d2569fbb6c70fcb"}

    def _swap_iov(train_val_df: pd.DataFrame, test_df: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """
        Swap the IOV images that are in the Train/Val set with images from the Test set of the same
        institution, to maintain the institution-wise distribution of images.
        """
        random.seed(0)
        # Filter out IOV subjects that are already in the test set (we do not want to swap them)
        iov_not_in_test = set([x for x in iov_subjects if x not in test_df.seriesId.unique()])
        iov_train_val_subjects = train_val_df[CSV_SERIES_HEADER].isin(iov_not_in_test)
        iov_train_val_subjects_df = train_val_df.loc[iov_train_val_subjects]
        # Drop IOV subjects
        train_val_df = train_val_df.loc[~iov_train_val_subjects]
        # Select the same number of subjects for the same institutions from the test set (ignoring the
        # IOV subjects that are already in the test set) and add them to the provided dataframe
        for x in iov_train_val_subjects_df.institutionId.unique():
            test_subs = list(test_df.loc[(test_df[CSV_INSTITUTION_HEADER] == x)
                                         & (~test_df[CSV_SERIES_HEADER].isin(iov_subjects))].subject.unique())
            num_train_val_df_subs_to_swap = len(
                iov_train_val_subjects_df.loc[
                    iov_train_val_subjects_df[CSV_INSTITUTION_HEADER] == x].subject.unique())
            subjects_to_swap = random.sample(test_subs, k=num_train_val_df_subs_to_swap)
            # Rows of the test dataframe to swap
            to_swap = test_df[CSV_SUBJECT_HEADER].isin(subjects_to_swap)
            # Swap
            train_val_df = pd.concat([train_val_df, test_df.loc[to_swap]])
            test_df = test_df.loc[~to_swap]
        return train_val_df, test_df

    train_swap, test_swap = _swap_iov(splits.train, splits.test)
    val_swap, test_swap = _swap_iov(splits.val, test_swap)
    test_swap = pd.concat(
        [test_swap, dataset_df.loc[dataset_df[CSV_SERIES_HEADER].isin(iov_subjects)]]).drop_duplicates()
    swapped_splits = DatasetSplits(
        train=train_swap,
        val=val_swap,
        test=test_swap
    )
    iov_intersection = set(swapped_splits.train.seriesId.unique()).intersection(iov_subjects)
    if len(iov_intersection) != 0:
        raise ValueError(f"Train split has IOV subjects {iov_intersection}")
    iov_intersection = set(swapped_splits.val.seriesId.unique()).intersection(iov_subjects)
    if len(iov_intersection) != 0:
        raise ValueError(f"Val split has IOV subjects {iov_intersection}")
    iov_missing = iov_subjects.difference(swapped_splits.test.seriesId.unique())
    if len(iov_missing) != 0:
        raise ValueError(f"All IOV subjects must be in the Test split, found {iov_missing} that are not")

    def _check_df_distribution(_old_df: pd.DataFrame, _new_df: pd.DataFrame) -> None:
        _old_df_inst = _old_df.drop_duplicates(CSV_SUBJECT_HEADER).groupby([CSV_INSTITUTION_HEADER]).groups
        _new_df_inst = _new_df.drop_duplicates(CSV_SUBJECT_HEADER).groupby([CSV_INSTITUTION_HEADER]).groups
        for k, v in _old_df_inst.items():
            if len(v) != len(_new_df_inst[k]):
                raise ValueError(f"Expected _new_df to have {len(v)} subjects for institution {k}, "
                                 f"found {len(_new_df_inst[k])}")

    _check_df_distribution(splits.train, swapped_splits.train)
    _check_df_distribution(splits.val, swapped_splits.val)
    _check_df_distribution(splits.test, swapped_splits.test)
    return swapped_splits
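# The `_swap_iov` helper above relies on a simple invariant: for every subject moved from Train/Val
# into Test, one non-IOV subject from the same institution is moved back, so the per-institution
# subject counts of each split are unchanged. The snippet below is a minimal, self-contained
# illustration of that invariant on made-up data; the column names and the `count_per_institution`
# helper are assumptions for this example, not part of the production code.
import pandas as _pd


def count_per_institution(df: _pd.DataFrame) -> dict:
    """Number of distinct subjects per institution."""
    return df.drop_duplicates("subject").groupby("institutionId")["subject"].count().to_dict()


toy_train = _pd.DataFrame({"subject": ["s1", "s2", "s3", "s4"],
                           "institutionId": ["iov", "iov", "other", "other"]})
toy_test = _pd.DataFrame({"subject": ["t1", "t2", "t3"],
                          "institutionId": ["iov", "iov", "other"]})

# Move the held-out subjects s1, s2 (institution "iov") from train to test, and move the same
# number of test subjects of that institution (t1, t2) back to train.
moved_to_test = toy_train[toy_train["subject"].isin({"s1", "s2"})]
moved_to_train = toy_test[toy_test["subject"].isin({"t1", "t2"})]
toy_train = _pd.concat([toy_train[~toy_train["subject"].isin({"s1", "s2"})], moved_to_train])
toy_test = _pd.concat([toy_test[~toy_test["subject"].isin({"t1", "t2"})], moved_to_test])

# Per-institution counts are unchanged: train still has 2 "iov" and 2 "other" subjects,
# test still has 2 "iov" and 1 "other" subject.
assert count_per_institution(toy_train) == {"iov": 2, "other": 2}
assert count_per_institution(toy_test) == {"iov": 2, "other": 1}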