def test_concrete_client_data(self):
  """Checks ConcreteClientData element spec, per-client sizes and preprocess.

  Client `k` is backed by `tf.data.Dataset.range(k)`, so its dataset should
  contain exactly `k` elements. After `preprocess(lambda ds: ds.take(1))`,
  every client's dataset should contain exactly one element.
  """
  ids = [1, 2, 3]

  def create_dataset_fn(client_id):
    # The client id doubles as the number of examples for that client.
    return tf.data.Dataset.range(client_id)

  data = cd.ConcreteClientData(
      client_ids=ids, create_tf_dataset_for_client_fn=create_dataset_fn)

  self.assertEqual(data.output_types, tf.int64)
  self.assertEqual(data.output_shapes, ())

  def num_elements(dataset):
    # Statically known size of the dataset, as a Python/NumPy scalar.
    return tf.data.experimental.cardinality(dataset).numpy()

  for cid in ids:
    self.assertEqual(
        num_elements(data.create_tf_dataset_for_client(cid)), cid)

  # Keep only the first example per client; every id above is >= 1, so each
  # client still has exactly one element left.
  first_only = data.preprocess(lambda ds: ds.take(1))
  for cid in ids:
    self.assertEqual(
        num_elements(first_only.create_tf_dataset_for_client(cid)), 1)
def test_client_sampling_with_one_client(self):
  """With a single available client, sampling one per round returns it."""
  dataset = client_data.ConcreteClientData([2], create_tf_dataset_for_client)
  sample_clients = sampling_utils.build_uniform_client_sampling_fn(
      dataset, clients_per_round=1)
  self.assertEqual(sample_clients(round_num=7), [2])
def test_client_sampling_fn_without_random_seed(self):
  """Unseeded sampling should not repeat itself for the same round.

  Samples 50 of 100 clients twice for round 0 with no `random_seed`.
  NOTE(review): this presumes the unseeded sampler draws fresh randomness on
  every call; an identical draw is possible in principle but vanishingly
  unlikely at this size.
  """
  dataset = client_data.ConcreteClientData(
      list(range(100)), create_tf_dataset_for_client)
  sample_clients = sampling_utils.build_uniform_client_sampling_fn(
      dataset, clients_per_round=50)
  first_draw = sample_clients(round_num=0)
  second_draw = sample_clients(round_num=0)
  self.assertNotEqual(first_draw, second_draw)
def test_different_random_seed_give_different_clients(self):
  """Two samplers seeded differently should pick different clients.

  Both samplers draw 50 of 100 clients for the same round (1001); only the
  `random_seed` (1 vs. 2) differs.
  """
  dataset = client_data.ConcreteClientData(
      list(range(100)), create_tf_dataset_for_client)

  sampler_seed_1 = sampling_utils.build_uniform_client_sampling_fn(
      dataset, clients_per_round=50, random_seed=1)
  clients_seed_1 = sampler_seed_1(round_num=1001)

  sampler_seed_2 = sampling_utils.build_uniform_client_sampling_fn(
      dataset, clients_per_round=50, random_seed=2)
  clients_seed_2 = sampler_seed_2(round_num=1001)

  self.assertNotEqual(clients_seed_1, clients_seed_2)
def test_client_sampling_fn_with_random_seed(self):
  """Samplers built with the same seed agree on the clients for a round.

  Two independently constructed samplers, both seeded with 363, each pick one
  of five clients for round 5; the picks must match.
  """
  dataset = client_data.ConcreteClientData(
      [0, 1, 2, 3, 4], create_tf_dataset_for_client)

  sampler_a = sampling_utils.build_uniform_client_sampling_fn(
      dataset, clients_per_round=1, random_seed=363)
  picked_a = sampler_a(round_num=5)

  sampler_b = sampling_utils.build_uniform_client_sampling_fn(
      dataset, clients_per_round=1, random_seed=363)
  picked_b = sampler_b(round_num=5)

  self.assertEqual(picked_a, picked_b)