Esempio n. 1
0
    def test_concrete_client_data(self):
        """Checks per-client dataset construction and `preprocess`.

        Client id ``k`` is mapped to ``tf.data.Dataset.range(k)``. The test
        asserts that the element type is ``tf.int64`` and the element shape
        is ``()``. It then asserts that client ``k`` has cardinality ``k``.
        After preprocessing with ``take(1)``, every client must have
        cardinality 1.
        """
        ids = [1, 2, 3]

        def make_dataset(cid):
            # The client id doubles as the number of examples for that client.
            return tf.data.Dataset.range(cid)

        data = cd.ConcreteClientData(
            client_ids=ids, create_tf_dataset_for_client_fn=make_dataset)

        self.assertEqual(data.output_types, tf.int64)
        self.assertEqual(data.output_shapes, ())

        def num_elements(dataset):
            return tf.data.experimental.cardinality(dataset).numpy()

        for cid in ids:
            dataset = data.create_tf_dataset_for_client(cid)
            self.assertEqual(num_elements(dataset), cid)

        # Keep only the first example of every client's dataset.
        truncated = data.preprocess(lambda ds: ds.take(1))
        for cid in ids:
            dataset = truncated.create_tf_dataset_for_client(cid)
            self.assertEqual(num_elements(dataset), 1)
Esempio n. 2
0
 def test_client_sampling_with_one_client(self):
     """With a single client id (2) available, sampling round 7 returns [2]."""
     only_client = 2
     dataset = client_data.ConcreteClientData(
         [only_client], create_tf_dataset_for_client)
     sample_fn = sampling_utils.build_uniform_client_sampling_fn(
         dataset, clients_per_round=1)
     sampled = sample_fn(round_num=7)
     self.assertEqual(sampled, [only_client])
Esempio n. 3
0
    def test_client_sampling_fn_without_random_seed(self):
        """Without a seed, two calls for the same round return different ids.

        The sampling function is built over 100 clients with 50 per round and
        no ``random_seed``. It is called twice for round 0.
        """
        dataset = client_data.ConcreteClientData(
            list(range(100)), create_tf_dataset_for_client)
        sample_fn = sampling_utils.build_uniform_client_sampling_fn(
            dataset, clients_per_round=50)
        first_draw, second_draw = [sample_fn(round_num=0) for _ in range(2)]
        self.assertNotEqual(first_draw, second_draw)
Esempio n. 4
0
    def test_different_random_seed_give_different_clients(self):
        """Sampling functions seeded with 1 and 2 pick different clients.

        Both functions sample 50 of 100 clients for round 1001.
        """
        dataset = client_data.ConcreteClientData(
            list(range(100)), create_tf_dataset_for_client)

        def sample_with_seed(seed):
            # Build a fresh sampling function for this seed and draw one round.
            sample_fn = sampling_utils.build_uniform_client_sampling_fn(
                dataset, clients_per_round=50, random_seed=seed)
            return sample_fn(round_num=1001)

        ids_seed_1 = sample_with_seed(1)
        ids_seed_2 = sample_with_seed(2)
        self.assertNotEqual(ids_seed_1, ids_seed_2)
Esempio n. 5
0
    def test_client_sampling_fn_with_random_seed(self):
        """Two sampling functions built with the same seed agree on each round.

        A large pool (100 clients, 50 per round, matching the other sampling
        tests) makes agreement meaningful. With a tiny pool such as 5 clients
        and 1 per round, two independent uniform draws would likely coincide
        by chance even if the seed were ignored. Several rounds are checked
        so that agreement on a single round is not a fluke.
        """
        dataset = client_data.ConcreteClientData(
            list(range(100)), create_tf_dataset_for_client)

        def sample_with_seed(seed, round_num):
            # A fresh sampling function per call, so no state is shared
            # between the two draws being compared.
            sample_fn = sampling_utils.build_uniform_client_sampling_fn(
                dataset, clients_per_round=50, random_seed=seed)
            return sample_fn(round_num=round_num)

        for round_num in (0, 5, 1001):
            client_ids_1 = sample_with_seed(363, round_num)
            client_ids_2 = sample_with_seed(363, round_num)
            self.assertEqual(client_ids_1, client_ids_2)