def test_flexible_splitter_local(self, create_all_datasets):
    """Check FlexibleSplitter when validation and test sets are partner-local.

    The hard-coded configuration holds ten values per partner row and is
    only exercised against datasets with exactly 10 classes. Any other
    dataset is reported as skipped instead of passing without running a
    single assertion.
    """
    dataset = create_all_datasets
    if dataset.num_classes != 10:
        pytest.skip("FlexibleSplitter configuration is written for 10 classes")

    splitter = FlexibleSplitter(
        [0.3, 0.3, 0.4],
        configuration=[
            [0.33, 0.33, 0.33, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
            [0.33, 0.33, 0.33, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0],
            [0.33, 0.33, 0.33, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0],
        ],
        val_set='local',
        test_set='local',
    )
    # One partner per requested amount.
    partners_list = [
        Partner(i) for i in range(len(splitter.amounts_per_partner))
    ]
    splitter.split(partners_list, dataset)

    for p in partners_list:
        assert len(p.y_val) > 0, \
            "validation set is empty in spite of the val_set == 'local'"
        assert len(p.y_test) > 0, \
            "test set is empty in spite of the test_set == 'local'"
        assert len(p.x_train) == len(p.y_train), \
            'labels and samples numbers mismatches'
        # No partner may end up holding every class of the dataset.
        assert len(p.labels) < dataset.num_classes, \
            f'Partner {p.id} has all labels.'
def test_advanced_splitter_local(self, create_all_datasets):
    """Check AdvancedSplitter when validation and test sets are partner-local.

    The leading count of every configuration entry is scaled by
    ``dataset.num_classes // 10``. For datasets with fewer than 10 classes
    that factor is 0, and the split is expected to raise. Otherwise the
    split must give every partner non-empty local validation/test sets,
    matching sample/label counts, and strictly fewer labels than the
    dataset has classes.
    """
    dataset = create_all_datasets
    scale = dataset.num_classes // 10
    splitter = AdvancedSplitter(
        [0.3, 0.3, 0.4],
        configuration=[
            [4 * scale, "specific"],
            [6 * scale, "shared"],
            [4 * scale, "shared"],
        ],
        val_set='local',
        test_set='local',
    )
    partners_list = [
        Partner(i) for i in range(len(splitter.amounts_per_partner))
    ]

    if dataset.num_classes < 10:
        # Every configured count is 0 here; splitting must fail.
        with pytest.raises(Exception):
            splitter.split(partners_list, dataset)
        return

    splitter.split(partners_list, dataset)
    for p in partners_list:
        assert len(p.y_val) > 0, \
            "validation set is empty in spite of the val_set == 'local'"
        assert len(p.y_test) > 0, \
            "test set is empty in spite of the test_set == 'local'"
        assert len(p.x_train) == len(p.y_train), \
            'labels and samples numbers mismatches'
        # The former `num_classes >= 3` guard was always true on this
        # path (num_classes >= 10), so the check runs unconditionally.
        assert len(p.labels) < dataset.num_classes, \
            f'Partner {p.id} has all labels.'
def test_random_splitter_global(self, create_all_datasets):
    """Check RandomSplitter built without explicit val_set/test_set.

    The splitter is constructed with amounts only; the assertion messages
    indicate the resulting mode is 'global' (presumably the splitter's
    default -- confirm against RandomSplitter). Each partner must then
    have empty validation and test sets, matching sample/label counts,
    and a share of the training data that does not exceed its requested
    amount by one sample or more.
    """
    splitter = RandomSplitter([0.1, 0.2, 0.3, 0.4])
    dataset = create_all_datasets
    partners_list = [
        Partner(i) for i in range(len(splitter.amounts_per_partner))
    ]
    splitter.split(partners_list, dataset)

    train_size = len(dataset.y_train)
    for p in partners_list:
        assert len(p.y_val) == 0, \
            "validation set is not empty in spite of the val_set == 'global'"
        assert len(p.y_test) == 0, \
            "test set is not empty in spite of the test_set == 'global'"
        assert len(p.x_train) == len(p.y_train), \
            'labels and samples numbers mismatches'
        # NOTE(review): this check is one-sided -- a partner receiving far
        # fewer samples than requested still passes. Consider abs(...) once
        # the splitter's rounding behaviour has been confirmed.
        excess = (p.final_nb_samples / train_size
                  - splitter.amounts_per_partner[p.id])
        assert excess < (1 / train_size), "Amounts of data not respected"
def test_stratified_splitter_local(self, create_all_datasets):
    """Check StratifiedSplitter when validation and test sets are partner-local.

    After splitting, every partner must have non-empty local validation
    and test sets, matching sample/label counts, and a share of the
    training data that does not exceed its requested amount by one sample
    or more. For datasets with at least 3 classes, no partner may hold
    every label.
    """
    splitter = StratifiedSplitter(
        [0.1, 0.2, 0.3, 0.4],
        val_set='local',
        test_set='local',
    )
    dataset = create_all_datasets
    partners_list = [
        Partner(i) for i in range(len(splitter.amounts_per_partner))
    ]
    splitter.split(partners_list, dataset)

    train_size = len(dataset.y_train)
    for p in partners_list:
        assert len(p.y_val) > 0, \
            "validation set is empty in spite of the val_set == 'local'"
        assert len(p.y_test) > 0, \
            "test set is empty in spite of the test_set == 'local'"
        assert len(p.x_train) == len(p.y_train), \
            'labels and samples numbers mismatches'
        if dataset.num_classes >= 3:
            assert len(p.labels) < dataset.num_classes, \
                f'Partner {p.id} has all labels.'
        # NOTE(review): this check is one-sided -- a partner receiving far
        # fewer samples than requested still passes. Confirm whether an
        # abs(...) comparison is intended.
        excess = (p.final_nb_samples / train_size
                  - splitter.amounts_per_partner[p.id])
        assert excess < (1 / train_size), "Amounts of data not respected"
def create_Partner(create_all_datasets):
    """Build partner 0 holding the first tenth of the training data.

    Labels and samples are each truncated to a tenth of their own length
    (rounded down) and attached to a fresh ``Partner(0)``, which is
    returned.
    """
    source = create_all_datasets
    # Each cut-off is derived from its own array, as in the original.
    label_cut = len(source.y_train) // 10
    sample_cut = len(source.x_train) // 10

    result = Partner(0)
    result.y_train = source.y_train[:label_cut]
    result.x_train = source.x_train[:sample_cut]
    return result