Example #1
def test_get_subject_ranges_for_splits() -> None:
    def _check_at_least_one(x: Dict[ModelExecutionMode, Set[str]]) -> None:
        assert all(len(x[mode]) >= 1 for mode in x.keys())

    proportions = [0.5, 0.4, 0.1]

    splits = DatasetSplits.get_subject_ranges_for_splits(['1', '2', '3'],
                                                         proportions[0],
                                                         proportions[1],
                                                         proportions[2])
    _check_at_least_one(splits)

    splits = DatasetSplits.get_subject_ranges_for_splits(['1'], proportions[0],
                                                         proportions[1],
                                                         proportions[2])
    assert splits[ModelExecutionMode.TRAIN] == {'1'}

    population = list(map(str, range(100)))
    splits = DatasetSplits.get_subject_ranges_for_splits(
        population, proportions[0], proportions[1], proportions[2])
    _check_at_least_one(splits)
    assert all([
        np.isclose(len(splits[mode]) / len(population), proportions[i])
        for i, mode in enumerate(splits.keys())
    ])
Example #2
def test_split_by_subject_ids_invalid(splits: List[List[str]]) -> None:
    df1 = pd.read_csv(full_ml_test_data_path(DATASET_CSV_FILE_NAME), dtype=str)
    with pytest.raises(ValueError):
        DatasetSplits.from_subject_ids(df1,
                                       train_ids=splits[0],
                                       val_ids=splits[1],
                                       test_ids=splits[2])
Example #3
    def get_model_train_test_dataset_splits(self, dataset_df: pd.DataFrame) -> DatasetSplits:
        if self.test_set_ids_csv:
            test_set_ids_csv = self.local_dataset / self.test_set_ids_csv
            test_series = pd.read_csv(test_set_ids_csv).series

            all_series = dataset_df.series.values
            check_all_test_series = all(test_series.isin(all_series))
            if not check_all_test_series:
                raise ValueError(f"Not all test series from {test_set_ids_csv} were found in the dataset.")

            test_set_subjects = dataset_df[dataset_df.series.isin(test_series)].subject.values
            train_and_val_series = dataset_df[~dataset_df.subject.isin(test_set_subjects)].series.values
            random.seed(42)
            random.shuffle(train_and_val_series)
            num_val_samples = math.floor(len(train_and_val_series) / 9)
            val_series = train_and_val_series[:num_val_samples]
            train_series = train_and_val_series[num_val_samples:]

            logging.info(f"Dropped {len(all_series) - (len(test_series) + len(train_and_val_series))} series "
                         f"due to subject overlap with test set.")
            return DatasetSplits.from_subject_ids(dataset_df,
                                                  train_ids=train_series,
                                                  val_ids=val_series,
                                                  test_ids=test_series,
                                                  subject_column="series",
                                                  group_column="subject")
        else:
            return DatasetSplits.from_proportions(dataset_df,
                                                  proportion_train=0.8,
                                                  proportion_val=0.1,
                                                  proportion_test=0.1,
                                                  subject_column="series",
                                                  group_column="subject",
                                                  shuffle=True)
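A quick sanity check of the arithmetic in the if-branch above, with purely hypothetical sizes: if the predefined test set holds about a tenth of the series, taking one ninth of the remainder for validation reproduces roughly the 0.8 / 0.1 / 0.1 proportions used in the else-branch.

import math

# Hypothetical sizes: 100 series in total, 10 of them in the predefined test set.
train_and_val = 100 - 10
num_val = math.floor(train_and_val / 9)   # 10 validation series
num_train = train_and_val - num_val       # 80 training series
assert (num_train, num_val) == (80, 10)   # mirrors the 0.8 / 0.1 / 0.1 proportions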
Example #4
def test_split_by_institution_invalid(splits: List[float]) -> None:
    df1 = pd.read_csv(full_ml_test_data_path(DATASET_CSV_FILE_NAME))
    with pytest.raises(ValueError):
        DatasetSplits.from_institutions(df1,
                                        splits[0],
                                        splits[1],
                                        splits[2],
                                        shuffle=False)
Example #5
 def get_model_train_test_dataset_splits(self, dataset_df: pd.DataFrame) -> DatasetSplits:
     return DatasetSplits.from_proportions(
         df=dataset_df,
         proportion_train=0.7,
         proportion_test=0.2,
         proportion_val=0.1,
     )
Example #6
 def get_model_train_test_dataset_splits(self, dataset_df: pd.DataFrame) -> DatasetSplits:
     return DatasetSplits.from_subject_ids(
         df=dataset_df,
         train_ids=['0', '1'],
         test_ids=['5'],
         val_ids=['2']
     )
Example #7
 def get_model_train_test_dataset_splits(
         self, dataset_df: pd.DataFrame) -> DatasetSplits:
     return DatasetSplits.from_proportions(dataset_df,
                                           proportion_train=0.8,
                                           proportion_val=0.05,
                                           proportion_test=0.15,
                                           random_seed=0)
Example #8
 def get_model_train_test_dataset_splits(
         self, dataset_df: pd.DataFrame) -> DatasetSplits:
     return DatasetSplits.from_institutions(df=dataset_df,
                                            proportion_train=0.6,
                                            proportion_test=0.2,
                                            proportion_val=0.2,
                                            shuffle=True)
Example #9
def test_restrict_subjects3() -> None:
    test_df, test_ids, train_ids, val_ids = _get_test_df()
    splits = DatasetSplits.from_subject_ids(test_df, train_ids, test_ids,
                                            val_ids).restrict_subjects(",0,+")
    assert len(splits.train.subject.unique()) == len(train_ids)
    assert len(splits.val.subject.unique()) == 0
    assert len(splits.test.subject.unique()) == len(test_ids) + len(val_ids)
Example #10
def test_split_by_subject_ids() -> None:
    test_df, test_ids, train_ids, val_ids = _get_test_df()
    splits = DatasetSplits.from_subject_ids(test_df, train_ids, test_ids,
                                            val_ids)

    for x, y in zip([splits.train, splits.test, splits.val],
                    [train_ids, test_ids, val_ids]):
        pd.testing.assert_frame_equal(x, test_df[test_df.subject.isin(y)])
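Several of these tests unpack a `_get_test_df` helper that is not shown in the listing. The following is only a plausible sketch of such a helper, assuming it builds a small one-row-per-subject dataframe and returns it together with disjoint subject ID lists; the column names, sizes, values and return order are guesses based on how the tests use the result.

import pandas as pd

def _get_test_df():
    # Hypothetical reconstruction for illustration only.
    subjects = [str(i) for i in range(30)]
    test_df = pd.DataFrame({
        "subject": subjects,
        "group": [str(i // 3) for i in range(30)],  # three subjects per group
        "other": list(range(30)),
    })
    train_ids, test_ids, val_ids = subjects[:20], subjects[20:25], subjects[25:]
    return test_df, test_ids, train_ids, val_ids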
Example #11
 def get_model_train_test_dataset_splits(
         self, dataset_df: pd.DataFrame) -> DatasetSplits:
     return DatasetSplits.from_subject_ids(
         df=dataset_df,
         train_ids=[1, 2, 3],
         val_ids=[4, 5],
         test_ids=[6],
     )
Example #12
def test_split_by_institution_exclude() -> None:
    """
    Test if splitting data by institution correctly handles the "exclude institution" flags.
    """
    # 40 subjects across 4 institutions
    test_data = {
        CSV_SUBJECT_HEADER: list(range(40)),
        CSV_INSTITUTION_HEADER: ["a", "b", "c", "d"] * 10,
        "other": list(range(0, 40))
    }
    df = DataFrame(test_data)
    all_inst = set(df[CSV_INSTITUTION_HEADER].unique())

    def check_inst_present(splits: DatasetSplits, expected: Set[str],
                           expected_test_set: Optional[Set[str]] = None) -> None:
        assert expected == set(splits.train[CSV_INSTITUTION_HEADER].unique())
        assert expected == set(splits.val[CSV_INSTITUTION_HEADER].unique())
        assert (expected_test_set or expected) == set(splits.test[CSV_INSTITUTION_HEADER].unique())

    # Normal functionality: all 4 institutions should be present in each of train, val, test
    splits = DatasetSplits.from_institutions(df, 0.5, 0.2, 0.3)
    check_inst_present(splits, all_inst)
    # Exclude institution "a" from all sets
    split1 = DatasetSplits.from_institutions(df, 0.5, 0.2, 0.3, exclude_institutions=["a"])
    check_inst_present(split1, {"b", "c", "d"})

    with pytest.raises(ValueError) as ex:
        DatasetSplits.from_institutions(df, 0.5, 0.2, 0.3, exclude_institutions=["not present"])
    assert "not present" in str(ex)

    # Put "a" only into the test set:
    split2 = DatasetSplits.from_institutions(df, 0.5, 0.2, 0.3, institutions_for_test_only=["a"])
    check_inst_present(split2, {"b", "c", "d"}, all_inst)

    with pytest.raises(ValueError) as ex:
        DatasetSplits.from_institutions(df, 0.5, 0.2, 0.3, institutions_for_test_only=["not present"])
    assert "not present" in str(ex)

    forced_subjects_in_test = list(df.subject.unique())[:20]
    split3 = DatasetSplits.from_institutions(df, 0.5, 0.2, 0.3, subject_ids_for_test_only=forced_subjects_in_test)
    assert set(split3.test.subject.unique()).issuperset(forced_subjects_in_test)

    with pytest.raises(ValueError) as ex:
        DatasetSplits.from_institutions(df, 0.5, 0.2, 0.3, subject_ids_for_test_only=['999'])
    assert "not present" in str(ex)
Example #13
 def get_model_train_test_dataset_splits(
         self, dataset_df: pd.DataFrame) -> DatasetSplits:
     return DatasetSplits.from_proportions(
         df=dataset_df,
         proportion_train=0.7,
         proportion_test=0.2,
         proportion_val=0.1,
         random_seed=1,
         subject_column=self.subject_column)
Example #14
 def get_cross_validation_dataset_splits(self, dataset_split: DatasetSplits) -> DatasetSplits:
     """
     When running cross validation, this method returns the dataset split that should be used for the
     currently executed cross validation split.
     :param dataset_split: The full dataset, split into training, validation and test section.
     :return: The dataset split with training and validation sections shuffled according to the current
     cross validation index.
     """
     splits = dataset_split.get_k_fold_cross_validation_splits(self.number_of_cross_validation_splits)
     return splits[self.cross_validation_split_index]
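Read together with the test in the next example, the intended flow appears to be: build the full split once, derive the k folds, then index into them with the current cross-validation split index; the test set is shared by every fold and only the train/val membership rotates. A minimal sketch, with illustrative variable names:

# Hypothetical usage of the two methods above.
base_splits = DatasetSplits.from_subject_ids(dataset_df,
                                             train_ids=train_ids,
                                             val_ids=val_ids,
                                             test_ids=test_ids)
folds = base_splits.get_k_fold_cross_validation_splits(n_splits=5)
current_fold = folds[2]  # e.g. cross_validation_split_index == 2
assert current_fold.test.equals(base_splits.test)  # the test set is identical in every fold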
Example #15
def test_get_k_fold_cross_validation_splits() -> None:
    # check the dataset splits have deterministic randomness
    for i in range(2):
        test_df, test_ids, train_ids, val_ids = _get_test_df()
        splits = DatasetSplits.from_subject_ids(test_df, train_ids, test_ids, val_ids)
        folds = splits.get_k_fold_cross_validation_splits(n_splits=5)
        assert len(folds) == 5
        assert all([x.test.equals(splits.test) for x in folds])
        assert all(
            [len(set(list(x.train.subject.unique()) + list(x.test.subject.unique()) + list(x.val.subject.unique()))
                 .difference(set(test_df.subject.unique()))) == 0 for x in folds])
Example #16
def test_grouped_splits(group_column: str) -> None:
    test_df = _get_test_df()[0]
    proportions = [0.5, 0.4, 0.1]
    splits = DatasetSplits.from_proportions(test_df,
                                            proportions[0],
                                            proportions[1],
                                            proportions[2],
                                            group_column=group_column)
    _check_is_partition(test_df, [splits.train, splits.test, splits.val],
                        CSV_SUBJECT_HEADER)
    _check_is_partition(test_df, [splits.train, splits.test, splits.val],
                        group_column)
Example #17
    def get_model_train_test_dataset_splits(
            self, dataset_df: pd.DataFrame) -> DatasetSplits:
        # The first 24 subject IDs are the designated test subjects in this dataset.
        test = list(range(0, 24))
        train_val = list(
            dataset_df[~dataset_df.subject.isin(test)].subject.unique())

        val = numpy.random.choice(train_val,
                                  int(len(train_val) * 0.1),
                                  replace=False)
        train = [x for x in train_val if x not in val]

        return DatasetSplits.from_subject_ids(df=dataset_df,
                                              test_ids=test,
                                              val_ids=val,
                                              train_ids=train)
Example #18
def test_split_by_institution() -> None:
    """
    Test if splitting by institution is as expected
    """
    random.seed(0)
    splits = [0.5, 0.4, 0.1]
    expected_split_sizes_per_institution = [[5, 3, 2], [45, 36, 9]]
    test_data = {
        CSV_SUBJECT_HEADER: list(range(0, 100)),
        CSV_INSTITUTION_HEADER: ([0] * 10) + ([1] * 90),
        "other": list(range(0, 100))
    }

    test_df = DataFrame(test_data, columns=list(test_data.keys()))
    dataset_splits = DatasetSplits.from_institutions(
        df=test_df,
        proportion_train=splits[0],
        proportion_val=splits[1],
        proportion_test=splits[2],
        shuffle=True)

    train_val_test = [
        dataset_splits.train, dataset_splits.val, dataset_splits.test
    ]
    # Check institution ratios are as expected
    def get_number_rows_for_institution(_x: DataFrame, _i: int) -> int:
        return len(_x.loc[test_df.institutionId == _i].subject.unique())

    for i, inst_id in enumerate(test_df.institutionId.unique()):
        # noinspection PyTypeChecker
        for j, df in enumerate(train_val_test):
            assert np.isclose(get_number_rows_for_institution(df, inst_id),
                              expected_split_sizes_per_institution[i][j])

    # Check that there are no overlaps between the datasets
    assert not set.intersection(*[set(x.subject) for x in train_val_test])

    # check that all of the data is preserved across the splits
    datasets_df = pd.concat(train_val_test)
    pd.testing.assert_frame_equal(
        datasets_df.sort_values([CSV_SUBJECT_HEADER], ascending=True), test_df)
Example #19
def test_grouped_k_fold_cross_validation_splits(group_column: str) -> None:
    test_df = _get_test_df()[0]
    proportions = [0.5, 0.4, 0.1]
    splits = DatasetSplits.from_proportions(test_df,
                                            proportions[0],
                                            proportions[1],
                                            proportions[2],
                                            group_column=group_column)

    n_splits = 7  # mutually prime with numbers of subjects and groups
    val_folds = []
    for fold in splits.get_k_fold_cross_validation_splits(n_splits):
        _check_is_partition(test_df, [fold.train, fold.test, fold.val],
                            CSV_SUBJECT_HEADER)
        _check_is_partition(test_df, [fold.train, fold.test, fold.val],
                            group_column)
        assert fold.test.equals(splits.test)
        val_folds.append(fold.val)

    # ensure validation folds partition the original train+val set
    train_val = pd.concat([splits.train, splits.val])
    _check_is_partition(train_val, val_folds, CSV_SUBJECT_HEADER)
    _check_is_partition(train_val, val_folds, group_column)
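The `_check_is_partition` helper used in these grouped-split tests is not shown. A plausible reconstruction, assuming it verifies that the parts are pairwise disjoint in the given column and together cover exactly the values of the full dataframe:

from typing import List
import pandas as pd

def _check_is_partition(df: pd.DataFrame, parts: List[pd.DataFrame], column: str) -> None:
    # Hypothetical reconstruction: no value of `column` may appear in two parts,
    # and the union of the parts must equal the set of values in the full df.
    part_sets = [set(part[column].unique()) for part in parts]
    union = set.union(*part_sets)
    assert sum(len(s) for s in part_sets) == len(union)  # pairwise disjoint
    assert union == set(df[column].unique())             # full coverage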
Example #20
def test_parse_and_check_restriction_pattern() -> None:
    assert DatasetSplits.parse_restriction_pattern("") == (None, None, None)
    assert DatasetSplits.parse_restriction_pattern("42") == (42, 42, 42)
    assert DatasetSplits.parse_restriction_pattern("1,2,3") == (1, 2, 3)
    assert DatasetSplits.parse_restriction_pattern("1,,3") == (1, None, 3)
    assert DatasetSplits.parse_restriction_pattern(",,3") == (None, None, 3)
    assert DatasetSplits.parse_restriction_pattern("+,0,3") == (sys.maxsize, 0,
                                                                3)
    assert DatasetSplits.parse_restriction_pattern("1,2,+") == (1, 2,
                                                                sys.maxsize)
    with pytest.raises(ValueError):
        # Neither 1 nor 3 fields
        DatasetSplits.parse_restriction_pattern("1,2")
    with pytest.raises(ValueError):
        # Neither 1 nor 3 fields
        DatasetSplits.parse_restriction_pattern("1,2,3,4")
    with pytest.raises(ValueError):
        # Equivalent to "+,+,+", and we only allow one "+" field.
        DatasetSplits.parse_restriction_pattern("+")
    with pytest.raises(ValueError):
        # This would mean "move the training set to validation AND to test".
        DatasetSplits.parse_restriction_pattern("0,+,+")
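Read together with Example #9 above, a restriction pattern appears to consist of three comma-separated fields for train, val and test: an integer caps that split at that many subjects, an empty field leaves it untouched, and the single permitted "+" field absorbs the subjects freed from the other splits. For instance, the ",0,+" pattern from Example #9 parses as follows:

import sys

# ",0,+": leave the training set as it is, empty the validation set,
# and let the test set absorb the subjects that were removed.
assert DatasetSplits.parse_restriction_pattern(",0,+") == (None, 0, sys.maxsize)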
Example #21
 def get_model_train_test_dataset_splits(
         self, dataset_df: pd.DataFrame) -> DatasetSplits:
     return DatasetSplits(
         train=dataset_df[dataset_df.subject.isin(['1'])],
         test=dataset_df[dataset_df.subject.isin(['2', '3'])],
         val=dataset_df[dataset_df.subject.isin(['4'])])
Example #22
 def get_model_train_test_dataset_splits(
         self, dataset_df: pd.DataFrame) -> DatasetSplits:
     return DatasetSplits(
         train=dataset_df[dataset_df.subject.isin(self.train_subject_ids)],
         test=dataset_df[dataset_df.subject.isin(self.test_subject_ids)],
         val=dataset_df[dataset_df.subject.isin(self.val_subject_ids)])
Example #23
def test_sequence_dataset_all(test_output_dirs: OutputFolderForTests) -> None:
    """
    Check that the sequence dataset works end-to-end, including applying the right standardization.
    """
    csv_string = """subject,seq,value,scalar1,scalar2,META,BETA
S1,0,False,0,0,M1,B1
S1,1,True,1,10,M2,B2
S2,0,False,2,20,M2,B1
S3,0,True,3,30,M1,B1
S4,0,True,4,40,M2,B1
"""
    csv_path = create_dataset_csv_file(csv_string, test_output_dirs.root_dir)
    config = SequenceModelBase(local_dataset=csv_path,
                               image_file_column=None,
                               label_value_column="value",
                               numerical_columns=["scalar1", "scalar2"],
                               sequence_target_positions=[0],
                               categorical_columns=["META", "BETA"],
                               sequence_column="seq",
                               num_dataload_workers=0,
                               train_batch_size=2,
                               should_validate=False,
                               shuffle=False)
    config.read_dataset_if_needed()
    df = config.dataset_data_frame
    assert df is not None
    df1 = df[df.subject.isin(["S1", "S2"])]
    df2 = df[df.subject == "S3"]
    df3 = df[df.subject == "S4"]
    splits = DatasetSplits(train=df1, val=df2, test=df3)
    with mock.patch.object(SequenceModelBase,
                           'get_model_train_test_dataset_splits',
                           return_value=splits):
        train_val_loaders = config.create_data_loaders()
        # Expected feature mean: Mean of the training data (0, 0), (1, 10), (2, 20) = (1, 10)
        # Expected (bias-corrected) std estimate: Std of (0, 0), (1, 10), (2, 20) = (1, 10)
        feature_stats = config.get_torch_dataset_for_inference(
            ModelExecutionMode.TRAIN).feature_statistics
        assert feature_stats is not None
        assert_tensors_equal(feature_stats.mean, [1, 10])
        assert_tensors_equal(feature_stats.std, [1, 10])

        train_items = list(
            ClassificationItemSequence.from_minibatch(b)
            for b in train_val_loaders[ModelExecutionMode.TRAIN])
        assert len(
            train_items
        ) == 1, "2 items in training set with batch size of 2 should return 1 minibatch"
        assert len(train_items[0]) == 2
        assert train_items[0][0].id == "S1"
        assert_tensors_equal(
            train_items[0][0].items[0].get_all_non_imaging_features(),
            [-1., -1., 1., 0., 1., 0.])
        assert_tensors_equal(
            train_items[0][0].items[1].get_all_non_imaging_features(),
            [0., 0., 0., 1., 0., 1.])
        assert train_items[0][1].id == "S2"
        assert_tensors_equal(
            train_items[0][1].items[0].get_all_non_imaging_features(),
            [1., 1., 0., 1., 1., 0.])
        val_items = list(
            ClassificationItemSequence.from_minibatch(b)
            for b in train_val_loaders[ModelExecutionMode.VAL])
        assert len(val_items) == 1
        assert len(val_items[0]) == 1
        assert val_items[0][0].id == "S3"
        # Items in the validation set should be normalized using the mean and std on the training data.
        # Hence, the non-image features (3, 30) should turn into (2, 2)
        assert_tensors_equal(
            val_items[0][0].items[0].get_all_non_imaging_features(),
            [2., 2., 1., 0., 1., 0.])

        # Check that the test set is also normalized correctly using the training mean and std.
        test_items = list(
            ClassificationItemSequence(**b) for b in
            config.get_torch_dataset_for_inference(ModelExecutionMode.TEST))
        assert test_items[0].id == "S4"
        # Check Non-image features of (4, 40)
        assert_tensors_equal(
            test_items[0].items[0].get_all_non_imaging_features(),
            [3., 3., 0., 1., 1., 0.])
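The feature statistics asserted above can be reproduced by hand: the training rows contribute the scalar pairs (0, 0), (1, 10) and (2, 20), so the mean is (1, 10) and the bias-corrected standard deviation (ddof=1) is also (1, 10); standardizing the validation row (3, 30) then gives (2, 2), and the test row (4, 40) gives (3, 3). A quick numpy check:

import numpy as np

train_features = np.array([[0.0, 0.0], [1.0, 10.0], [2.0, 20.0]])
mean = train_features.mean(axis=0)         # [1., 10.]
std = train_features.std(axis=0, ddof=1)   # [1., 10.], bias-corrected estimate
assert np.allclose(mean, [1.0, 10.0]) and np.allclose(std, [1.0, 10.0])
assert np.allclose((np.array([3.0, 30.0]) - mean) / std, [2.0, 2.0])  # validation row
assert np.allclose((np.array([4.0, 40.0]) - mean) / std, [3.0, 3.0])  # test row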
Example #24
    def get_model_train_test_dataset_splits(self, dataset_df: pd.DataFrame) -> DatasetSplits:
        splits = DatasetSplits.from_institutions(
            df=dataset_df,
            proportion_train=0.8,
            proportion_test=0.1,
            proportion_val=0.1,
            shuffle=True,
            exclude_institutions={
                "ac54f75d-c5fa-4e32-a140-485527f8e3a2",  # Birmingham: 1 image
                "af8d9205-2ae1-422f-8b35-67ee435253e1",  # OSL: 2 images
                "87630c93-07d6-49de-844a-3cc99fe9c323",  # Brussels: 3 images
                "5a6ba8fe-65bc-43ec-b1fc-682c8c37e40c",  # VFN: 4 images
            },
            # These institutions have around 40 images each. The main argument in the paper will be about
            # keeping two of those aside as untouched test sets.
            # Oncoclinicas uses Siemens scanner, IOV uses a GE scanner. Most of the other images are from Toshiba
            # scanners.
            institutions_for_test_only={
                # "d527557d-3b9a-45d0-ad57-692e5a199896",  # AZ Groenige
                "85aaee5f-f5f3-4eae-b6cd-26b0070156d8",  # IOV
                "641eda02-90c3-45ed-b8b1-2651b6a5da6c",  # Oncoclinicas
                # "8522ccd1-ab59-4342-a2ce-7d8ad826ab4f",  # UW
            }
        )

        # IOV subjects not in the test set already
        iov_subjects = {
            "1ec8a7d58cadb231a0412b674731ee72da0e04ab67f2a2f009a768189bbcf691",
            "439bc48993c6e146c4ab573eeba35990ee843b7495dd0924dc6bd0b331e869db",
            "e5d338a12dfcc519787456b09072a07c6191b7140e036c52bc4d039ef3b28afd",
            "af7ad87cc408934cb2a65029661cb426539429a8aada6e1644a67a056c94f691",
            "227e859ee0bd0c4ff860dd77a20f39fe5924348ff4a4fac15dc94cea2cd07c39",
            "512b22856b7dbde60b4a42c348c4bee5b9efb67024fb708addcddfe1f4841288",
            "906f77caba56df060f5d519ae9b6572a90ac22a04560b4d561f3668e6331e3c3",
            "49a01ffe812b0f3e3d93334866662afb5fb33ba6dcd3cc642d4577a449000649",
            "ab3ed87d55da37a2a665b059b5fef54a0553656e8df51592b8c40f16facd60b9",
            "6eb8aeb8f822e15970d3feb64a618a9ad3de936046d84cb83d2569fbb6c70fcb"}

        def _swap_iov(train_val_df: pd.DataFrame, test_df: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame]:
            """
            Swap the IOV images that ended up in the Train/Val set with images from the Test set
            of the same institution (to maintain the institution-wise distribution of images).
            """
            random.seed(0)
            # keep only the IOV subjects that are not already in the test set (those do not need to be swapped)
            iov_not_in_test = set([x for x in iov_subjects if x not in test_df.seriesId.unique()])

            iov_train_val_subjects = train_val_df[CSV_SERIES_HEADER].isin(iov_not_in_test)
            iov_train_val_subjects_df = train_val_df.loc[iov_train_val_subjects]
            # drop IOV subjects
            train_val_df = train_val_df.loc[~iov_train_val_subjects]
            # select the same number of subjects per institution from the test set (ignoring the IOV subjects
            # that are already in the test set) and add them to the provided df
            for x in iov_train_val_subjects_df.institutionId.unique():
                test_subs = list(test_df.loc[(test_df[CSV_INSTITUTION_HEADER] == x) &
                                             (~test_df[CSV_SERIES_HEADER].isin(iov_subjects))].subject.unique())
                num_train_val_df_subs_to_swap = len(
                    iov_train_val_subjects_df.loc[
                        iov_train_val_subjects_df[CSV_INSTITUTION_HEADER] == x].subject.unique())
                subjects_to_swap = random.sample(test_subs, k=num_train_val_df_subs_to_swap)
                # test df to swap
                to_swap = test_df[CSV_SUBJECT_HEADER].isin(subjects_to_swap)
                # swap
                train_val_df = pd.concat([train_val_df, test_df.loc[to_swap]])
                test_df = test_df.loc[~to_swap]

            return train_val_df, test_df

        train_swap, test_swap = _swap_iov(splits.train, splits.test)
        val_swap, test_swap = _swap_iov(splits.val, test_swap)
        test_swap = pd.concat(
            [test_swap, dataset_df.loc[dataset_df[CSV_SERIES_HEADER].isin(iov_subjects)]]).drop_duplicates()

        swapped_splits = DatasetSplits(
            train=train_swap,
            val=val_swap,
            test=test_swap
        )

        iov_intersection = set(swapped_splits.train.seriesId.unique()).intersection(iov_subjects)
        if len(iov_intersection) != 0:
            raise ValueError(f"Train split has IOV subjects {iov_intersection}")
        iov_intersection = set(swapped_splits.val.seriesId.unique()).intersection(iov_subjects)
        if len(iov_intersection) != 0:
            raise ValueError(f"Val split has IOV subjects {iov_intersection}")

        iov_missing = iov_subjects.difference(swapped_splits.test.seriesId.unique())
        if len(iov_missing) != 0:
            raise ValueError(f"All IOV subjects must be in the Test split, found {iov_missing} that are not")

        def _check_df_distribution(_old_df: pd.DataFrame, _new_df: pd.DataFrame) -> None:
            _old_df_inst = _old_df.drop_duplicates(CSV_SUBJECT_HEADER).groupby([CSV_INSTITUTION_HEADER]).groups
            _new_df_inst = _new_df.drop_duplicates(CSV_SUBJECT_HEADER).groupby([CSV_INSTITUTION_HEADER]).groups
            for k, v in _old_df_inst.items():
                if len(v) != len(_new_df_inst[k]):
                    raise ValueError(f"Expected _new_df to be length={len(v)} found {_new_df_inst[k]}")

        _check_df_distribution(splits.train, swapped_splits.train)
        _check_df_distribution(splits.val, swapped_splits.val)
        _check_df_distribution(splits.test, swapped_splits.test)

        return swapped_splits
Example #25
 def get_model_train_test_dataset_splits(self, dataset_df: pd.DataFrame) -> DatasetSplits:
     return DatasetSplits(train=dataset_df[dataset_df.subject.isin([1, 2])],
                          test=dataset_df[dataset_df.subject.isin([3, 4])],
                          val=dataset_df[dataset_df.subject.isin([5, 6])])
Example #26
 def get_model_train_test_dataset_splits(
         self, dataset_df: pd.DataFrame) -> DatasetSplits:
     return DatasetSplits(
         train=dataset_df[dataset_df.subject.isin(train)],
         test=dataset_df[dataset_df.subject.isin(test)],
         val=dataset_df[dataset_df.subject.isin(val)])