def test_dataset_length(self):
  """Checks per-split example counts for client '0' of the default split.

  The train dataset for client '0' must hold 5000 examples and the test
  dataset for the same client must hold 1000. The train split is checked
  first; a failure there stops the test before the test split is checked.
  """
  train_data, test_data = cifar10_dataset.load_cifar10_federated()
  expectations = (
      (train_data, 5000),
      (test_data, 1000),
  )
  for client_data, wanted_size in expectations:
    client_ds = client_data.create_tf_dataset_for_client('0')
    self.assertEqual(_compute_length_of_dataset(client_ds), wanted_size)
def test_dataset_length_8_clients(self):
  """Checks per-split example counts for client '0' when num_clients=8.

  With eight clients, client '0' must hold 6250 train examples and 1250
  test examples. The train split is asserted before the test split.
  """
  train_data, test_data = cifar10_dataset.load_cifar10_federated(
      num_clients=8)
  expectations = (
      (train_data, 6250),
      (test_data, 1250),
  )
  for client_data, wanted_size in expectations:
    client_ds = client_data.create_tf_dataset_for_client('0')
    self.assertEqual(_compute_length_of_dataset(client_ds), wanted_size)
def test_num_clients(self):
  """Checks that the default split exposes 10 client ids in each split."""
  train_split, test_split = cifar10_dataset.load_cifar10_federated()
  for split in (train_split, test_split):
    self.assertEqual(len(split.client_ids), 10)