def test_flexible_splitter_local(self, create_all_datasets):
     """Check FlexibleSplitter with local validation and test sets.

     The configuration rows below have 10 entries each, and the test is only
     exercised on datasets with exactly 10 classes. Other datasets are now
     reported as skipped instead of silently passing.

     For every partner, the test asserts:
     - non-empty local validation and test sets;
     - equal train sample and label counts;
     - strictly fewer labels than the dataset has.
     """
     dataset = create_all_datasets
     if dataset.num_classes != 10:
         # Previously this case fell through with no assertions at all.
         pytest.skip('FlexibleSplitter configuration is defined for 10-class datasets only')
     splitter = FlexibleSplitter(
         [0.3, 0.3, 0.4],
         configuration=[
             [0.33, 0.33, 0.33, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
             [0.33, 0.33, 0.33, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0],
             [0.33, 0.33, 0.33, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0],
         ],
         val_set='local',
         test_set='local')
     partners_list = [
         Partner(i) for i in range(len(splitter.amounts_per_partner))
     ]
     splitter.split(partners_list, dataset)
     for p in partners_list:
         assert len(p.y_val) > 0, \
             "validation set is empty in spite of the val_set == 'local'"
         assert len(p.y_test) > 0, \
             "test set is empty in spite of the test_set == 'local'"
         assert len(p.x_train) == len(p.y_train), \
             'labels and samples numbers mismatches'
         assert len(p.labels) < dataset.num_classes, \
             f'Partner {p.id} has all labels.'
 def test_advanced_splitter_local(self, create_all_datasets):
     """Check AdvancedSplitter with local validation and test sets.

     The first element of each configuration entry is scaled by
     ``num_classes // 10``. For datasets with fewer than 10 classes that
     factor is 0, and ``split`` is expected to raise.

     Otherwise, for every partner, the test asserts:
     - non-empty local validation and test sets;
     - equal train sample and label counts;
     - strictly fewer labels than the dataset has.
     """
     dataset = create_all_datasets
     # Integer multiplier applied to the first element of each entry.
     scale = dataset.num_classes // 10
     splitter = AdvancedSplitter(
         [0.3, 0.3, 0.4],
         configuration=[[4 * scale, "specific"],
                        [6 * scale, "shared"],
                        [4 * scale, "shared"]],
         val_set='local',
         test_set='local')
     partners_list = [
         Partner(i) for i in range(len(splitter.amounts_per_partner))
     ]
     if dataset.num_classes < 10:
         with pytest.raises(Exception):
             splitter.split(partners_list, dataset)
         return
     splitter.split(partners_list, dataset)
     for p in partners_list:
         assert len(p.y_val) > 0, \
             "validation set is empty in spite of the val_set == 'local'"
         assert len(p.y_test) > 0, \
             "test set is empty in spite of the test_set == 'local'"
         assert len(p.x_train) == len(p.y_train), \
             'labels and samples numbers mismatches'
         # num_classes >= 10 here, so the former `>= 3` guard was always true.
         assert len(p.labels) < dataset.num_classes, \
             f'Partner {p.id} has all labels.'
 def test_random_splitter_global(self, create_all_datasets):
     """Check RandomSplitter built without val_set/test_set arguments.

     The assertions expect the 'global' behaviour: partners hold no local
     validation or test data. Train sample and label counts must match.

     Each partner's ``final_nb_samples``, as a share of
     ``dataset.y_train``, must be within less than one sample of the
     requested amount. The bound is checked in both directions.
     """
     splitter = RandomSplitter([0.1, 0.2, 0.3, 0.4])
     dataset = create_all_datasets
     partners_list = [
         Partner(i) for i in range(len(splitter.amounts_per_partner))
     ]
     splitter.split(partners_list, dataset)
     # Read after split(), matching the original evaluation order.
     n_train = len(dataset.y_train)
     for p in partners_list:
         assert len(p.y_val) == 0, \
             "validation set is not empty in spite of the val_set == 'global'"
         assert len(p.y_test) == 0, \
             "test set is not empty in spite of the test_set == 'global'"
         assert len(p.x_train) == len(p.y_train), \
             'labels and samples numbers mismatches'
         # abs(): the previous one-sided check let a partner with far too
         # few samples pass unnoticed.
         deviation = p.final_nb_samples / n_train - splitter.amounts_per_partner[p.id]
         assert abs(deviation) < 1 / n_train, "Amounts of data not respected"
 def test_stratified_splitter_local(self, create_all_datasets):
     """Check StratifiedSplitter with local validation and test sets.

     For every partner, the test asserts:
     - non-empty local validation and test sets;
     - equal train sample and label counts;
     - on datasets with at least 3 classes, not every label.

     It also asserts that ``final_nb_samples`` does not exceed the partner's
     requested share of ``dataset.y_train`` by one sample or more.
     """
     splitter = StratifiedSplitter([0.1, 0.2, 0.3, 0.4],
                                   val_set='local',
                                   test_set='local')
     dataset = create_all_datasets
     partners_list = [
         Partner(i) for i in range(len(splitter.amounts_per_partner))
     ]
     splitter.split(partners_list, dataset)
     # Read after split(), matching the original evaluation order.
     n_train = len(dataset.y_train)
     for p in partners_list:
         assert len(p.y_val) > 0, \
             "validation set is empty in spite of the val_set == 'local'"
         assert len(p.y_test) > 0, \
             "test set is empty in spite of the test_set == 'local'"
         assert len(p.x_train) == len(p.y_train), \
             'labels and samples numbers mismatches'
         if dataset.num_classes >= 3:
             assert len(p.labels) < dataset.num_classes, \
                 f'Partner {p.id} has all labels.'
         # NOTE(review): this bound is one-sided (only an excess is caught).
         # That is presumably intentional if final_nb_samples excludes the
         # locally carved-out val/test samples. Confirm before tightening
         # it with abs().
         assert (p.final_nb_samples / n_train
                 - splitter.amounts_per_partner[p.id]) < 1 / n_train, \
             "Amounts of data not respected"
def create_Partner(create_all_datasets):
    """Return partner 0 loaded with the leading tenth of the training data.

    ``x_train`` and ``y_train`` are each truncated to one tenth (rounded
    down) of their own length.
    """
    source = create_all_datasets
    partner = Partner(0)
    labels_cut = len(source.y_train) // 10
    partner.y_train = source.y_train[:labels_cut]
    samples_cut = len(source.x_train) // 10
    partner.x_train = source.x_train[:samples_cut]
    return partner