def test_client_data_constructs_with_correct_clients_and_types(self):
  """Construction exposes both client ids and a scalar int32 element spec."""
  slices = {'a': [1, 2, 3], 'b': [4, 5]}
  client_data = from_tensor_slices_client_data.TestClientData(slices)
  # Both client keys must be present; ordering is irrelevant.
  self.assertCountEqual(client_data.client_ids, ['a', 'b'])
  # Plain Python int lists slice into scalar int32 elements.
  self.assertEqual(client_data.element_type_structure,
                   tf.TensorSpec(shape=(), dtype=tf.int32))
def test_build_synthethic_iid_client_data(self):
  """Synthetic IID datasets have the configured size and vary run-to-run.

  NOTE(review): "synthethic" mirrors the (misspelled) name of the project
  API `dataset_utils.build_synthethic_iid_datasets`, so it is kept as-is.
  """
  # Create a fake, very non-IID ClientData.
  client_datasets = collections.OrderedDict(a=[1] * 3, b=[2] * 5, c=[3] * 7)
  non_iid_client_data = from_tensor_slices_client_data.TestClientData(
      client_datasets)
  iid_client_data_iter = iter(
      dataset_utils.build_synthethic_iid_datasets(non_iid_client_data,
                                                  client_dataset_size=5))
  num_synthethic_clients = 3
  run_results = []
  for _ in range(5):
    actual_iid_client_datasets = []
    for _ in range(num_synthethic_clients):
      # Each `next` draws one synthetic client dataset; materialize it so
      # runs can be compared as plain Python lists below.
      dataset = next(iid_client_data_iter)
      actual_iid_client_datasets.append(
          [self.evaluate(x) for x in dataset])
    # We expect 3 datasets: 15 examples in the global dataset, synthetic
    # non-iid configured for 5 examples per client.
    self.assertEqual([5, 5, 5],
                     [len(d) for d in actual_iid_client_datasets])
    run_results.append(actual_iid_client_datasets)
  # Assert no run is the same. The chance that two runs are the same is far
  # less than 1 in a million, flakes should be imperceptible.
  for i, run_a in enumerate(run_results[:-1]):
    for run_b in run_results[i + 1:]:
      self.assertNotEqual(run_a, run_b, msg=str(run_results))
def test_dataset_computation_where_client_data_is_tuples(self):
  """dataset_computation has the right signature and data for tuple clients."""
  client_data = from_tensor_slices_client_data.TestClientData(
      TEST_DATA_WITH_TUPLES)
  dataset_computation = client_data.dataset_computation
  self.assertIsInstance(dataset_computation, computation_base.Computation)
  # Expected signature: string id -> sequence of tensors of unknown shape.
  element_dtype = client_data.element_type_structure[0].dtype
  expected_signature = computation_types.FunctionType(
      computation_types.to_type(tf.string),
      computation_types.SequenceType(
          computation_types.TensorType(element_dtype, tf.TensorShape(None))))
  self.assertTrue(
      dataset_computation.type_signature.is_equivalent_to(expected_signature))
  # Iterate over each client, invoking the dataset_computation and ensuring
  # we received a tf.data.Dataset with the correct data.
  for client_id, expected_data in TEST_DATA_WITH_TUPLES.items():
    tf_dataset = dataset_computation(client_id)
    self.assertIsInstance(tf_dataset, tf.data.Dataset)
    self.assertLen(expected_data, tf_dataset.cardinality())
    # Every element must match the expected slice at the same index.
    for expected, actual in zip(expected_data, tf_dataset):
      self.assertAllEqual(np.asarray(expected), actual.numpy())
def test_create_tf_dataset_from_all_clients(self):
  """Flattening expanded pseudo-clients yields exactly the features 0..12."""

  # Expands `CLIENT {N}` into N pseudo-clients which add range(N) to the
  # feature.
  def expand_client_id(client_id):
    count = int(client_id[-1])
    return [client_id + '-' + str(i) for i in range(count)]

  def make_transform_fn(client_id):
    parts = tf.strings.split(client_id, '-')
    offset = tf.cast(tf.strings.to_number(parts[1]), tf.int32)
    return lambda x: x + offset

  def reduce_client_id(client_id):
    return tf.strings.split(client_id, sep='-')[0]

  raw_data = {
      'CLIENT 1': [0],         # expanded to [0]
      'CLIENT 2': [1, 3, 5],   # expanded to [1, 3, 5], [2, 4, 6]
      'CLIENT 3': [7, 10],     # expanded to [7, 10], [8, 11], [9, 12]
  }
  client_data = from_tensor_slices_client_data.TestClientData(raw_data)
  transformed_client_data = transforming_client_data.TransformingClientData(
      client_data, make_transform_fn, expand_client_id, reduce_client_id)

  flat_data = transformed_client_data.create_tf_dataset_from_all_clients()
  self.assertIsInstance(flat_data, tf.data.Dataset)
  # The union of all expanded clients covers 0..12 exactly once each.
  all_features = [batch.numpy() for batch in flat_data]
  self.assertCountEqual(all_features, range(13))
def test_dataset_computation_raises_error_if_unknown_client_id(self):
  """Invoking dataset_computation with an unknown id raises a TF error."""
  data = from_tensor_slices_client_data.TestClientData(TEST_DATA)
  with self.assertRaises(tf.errors.InvalidArgumentError):
    data.dataset_computation(CLIENT_ID_NOT_IN_TEST_DATA)
def test_default_num_transformed_clients(self):
  """Without an explicit count there is one transformed client per raw one."""
  raw_client_data = from_tensor_slices_client_data.TestClientData(TEST_DATA)
  transformed = transforming_client_data.TransformingClientData(
      raw_client_data, _test_transform_cons)
  # Same number of ids as the raw data, and no pseudo-client expansion.
  self.assertLen(transformed.client_ids, len(TEST_DATA))
  self.assertFalse(transformed._has_pseudo_clients)
def test_shuffle_client_ids(self):
  """Seeds reorder but preserve contents; seed=None draws fresh seeds."""
  slices = {'a': [1, 1], 'b': [2, 2, 2], 'c': [3], 'd': [4, 4]}
  expected_multiset = [1, 1, 2, 2, 2, 3, 4, 4]
  client_data = from_tensor_slices_client_data.TestClientData(slices)

  def flatten(seed):
    dataset = client_data.create_tf_dataset_from_all_clients(seed=seed)
    return [value.numpy() for value in dataset]

  first = flatten(123)
  second = flatten(456)
  # Different random seeds, different order...
  self.assertNotEqual(first, second)
  # ...but identical contents regardless of shuffling.
  self.assertCountEqual(first, expected_multiset)
  self.assertCountEqual(second, expected_multiset)

  # Test that the default behavior is to use a fresh random seed. Two
  # draws could coincide by chance, but we are very unlikely to get
  # unlucky 100 times in a row.
  self.assertTrue(
      any(flatten(seed=None) != flatten(seed=None) for _ in range(100)))
def test_serializable_dataset_fn_constructs(self):
  """serializable_dataset_fn returns the raw slices for a client."""
  client_data = from_tensor_slices_client_data.TestClientData(
      {'a': [1, 2, 3], 'b': [4, 5]})
  dataset = client_data.serializable_dataset_fn('a')
  self.assertEqual([self.evaluate(x) for x in dataset], [1, 2, 3])
def test_client_ids_property(self):
  """Expanded client ids are sorted strings and pseudo-clients are flagged."""
  base = from_tensor_slices_client_data.TestClientData(TEST_DATA)
  expansion = 7
  transformed = transforming_client_data.TransformingClientData(
      base, _test_transform_cons, expansion)
  ids = transformed.client_ids
  self.assertLen(ids, expansion)
  for client_id in ids:
    self.assertIsInstance(client_id, str)
  # The ids must come back already sorted.
  self.assertListEqual(ids, sorted(ids))
  self.assertTrue(transformed._has_pseudo_clients)
def test_where_client_data_is_tensors(self):
  """Tensor-valued clients keep their ids, element spec, and contents."""
  client_data = from_tensor_slices_client_data.TestClientData(TEST_DATA)
  self.assertCountEqual(TEST_DATA.keys(), client_data.client_ids)
  # Each element is a length-2 int32 vector.
  self.assertEqual(client_data.element_type_structure,
                   tf.TensorSpec(shape=(2,), dtype=tf.int32))
  for client_id, slices in TEST_DATA.items():
    expected = tf.data.Dataset.from_tensor_slices(slices)
    actual = client_data.create_tf_dataset_for_client(client_id)
    self.assertSameDatasets(expected, actual)
def get_synthetic():
  """Returns a small synthetic dataset for testing.

  The single client produced has exactly 5 examples. The images and labels
  are derived from a fixed set of hard-coded images.

  Returns:
    A `tff.simulation.datasets.ClientData` object that matches the
    characteristics (other than size) of those provided by
    `tff.simulation.datasets.cifar100.load_data`.
  """
  synthetic_data = {'synthetic': _get_synthetic_digits_data()}
  return from_tensor_slices_client_data.TestClientData(synthetic_data)
def get_synthetic():
  """Returns a small synthetic dataset for testing.

  Provides two clients, each client with only 3 examples. The examples are
  derived from a fixed set of examples in the larger dataset, but are not
  exact copies.

  Returns:
    A `tff.simulation.datasets.ClientData` object that matches the
    characteristics (other than size) of those provided by
    `tff.simulation.datasets.stackoverflow.load_data`.
  """
  data_dictionary = create_synthetic_data_dictionary()
  return from_tensor_slices_client_data.TestClientData(data_dictionary)
def get_synthetic() -> client_data.ClientData:
  """Creates an in-memory `tff.simulation.datasets.ClientData` of Shakespeare.

  The returned `tff.simulation.datasets.ClientData` will have the same data
  schema as `load_data()`, but uses a very small set of client data loaded
  in-memory. This synthetic data is useful for validation in small tests.

  Returns:
    A `tff.simulation.datasets.ClientData` of synthetic Shakespeare text.
  """
  return from_tensor_slices_client_data.TestClientData(
      _SYNTHETIC_SHAKESPEARE_DATA)
def test_basic(self):
  """Construction exposes client ids, element spec, and per-client data."""
  client_data = from_tensor_slices_client_data.TestClientData(
      {'a': [1, 2, 3], 'b': [4, 5]})
  self.assertCountEqual(client_data.client_ids, ['a', 'b'])
  self.assertEqual(client_data.element_type_structure,
                   tf.TensorSpec(shape=(), dtype=tf.int32))

  def materialize(dataset):
    return [self.evaluate(x) for x in dataset]

  # Each client's dataset yields exactly the slices it was built from.
  self.assertEqual(
      materialize(client_data.create_tf_dataset_for_client('a')), [1, 2, 3])
  self.assertEqual(
      materialize(client_data.create_tf_dataset_for_client('b')), [4, 5])
def get_synthetic() -> ClientData:
  """Returns a small synthetic dataset for testing.

  The single client produced has 3 examples generated pseudo-randomly.

  Returns:
    A `tff.simulation.datasets.ClientData`.
  """
  # Deterministic pseudo-random images: the stateless RNG is keyed only by
  # the example index, so repeated calls produce identical data.
  image_list = [
      tf.random.stateless_normal(shape=(128, 128, 3), seed=(0, i))
      for i in range(3)
  ]
  stacked_images = tf.cast(tf.stack(image_list, axis=0), dtype=tf.uint8)
  label_tensor = tf.constant([0, 1, 2], dtype=tf.int64)
  features = collections.OrderedDict([('image/decoded', stacked_images),
                                      ('class', label_tensor)])
  return from_tensor_slices_client_data.TestClientData({'synthetic': features})
def test_where_client_data_is_ordered_dicts(self):
  """OrderedDict-valued clients keep their ids, element spec, and contents."""
  client_data = from_tensor_slices_client_data.TestClientData(
      TEST_DATA_WITH_ORDEREDDICTS)
  self.assertCountEqual(TEST_DATA_WITH_ORDEREDDICTS.keys(),
                        client_data.client_ids)
  expected_spec = collections.OrderedDict([
      ('x', tf.TensorSpec(shape=(2,), dtype=tf.int32)),
      ('y', tf.TensorSpec(shape=(), dtype=tf.float32)),
      ('z', tf.TensorSpec(shape=(), dtype=tf.string)),
  ])
  self.assertEqual(expected_spec, client_data.element_type_structure)
  for client_id, slices in TEST_DATA_WITH_ORDEREDDICTS.items():
    self.assertSameDatasetsOfDicts(
        tf.data.Dataset.from_tensor_slices(slices),
        client_data.create_tf_dataset_for_client(client_id))
def test_fail_on_bad_client_id(self):
  """Unknown or out-of-range pseudo-client ids raise ValueError."""
  base = from_tensor_slices_client_data.TestClientData(TEST_DATA)
  transformed = transforming_client_data.TransformingClientData(
      base, _test_transform_cons, 7)
  # The following three should be valid.
  for valid_id in ('CLIENT A_1', 'CLIENT B_1', 'CLIENT A_2'):
    transformed.create_tf_dataset_for_client(valid_id)
  expected_message = 'client_id must be a valid string from client_ids.'
  # This should not be valid: no corresponding raw client.
  with self.assertRaisesRegex(ValueError, expected_message):
    transformed.create_tf_dataset_for_client('CLIENT D_0')
  # This should not be valid: expansion index out of range.
  with self.assertRaisesRegex(ValueError, expected_message):
    transformed.create_tf_dataset_for_client('CLIENT B_2')
def test_create_tf_dataset_for_client(self):
  """Each pseudo-client's dataset equals the raw data with 'x' offset.

  Pseudo-client ids have the form '<raw id>_<index>'; the transform is
  expected to add 10 * index to feature 'x' (per `_test_transform_cons`,
  defined elsewhere in this file — TODO confirm).
  """
  client_data = from_tensor_slices_client_data.TestClientData(TEST_DATA)
  transformed_client_data = transforming_client_data.TransformingClientData(
      client_data, _test_transform_cons, 9)
  for client_id in transformed_client_data.client_ids:
    tf_dataset = transformed_client_data.create_tf_dataset_for_client(
        client_id)
    self.assertIsInstance(tf_dataset, tf.data.Dataset)
    # Split the id back into the raw client name and the expansion index.
    pattern = r'^(.*)_(\d*)$'
    match = re.search(pattern, client_id)
    client = match.group(1)
    index = int(match.group(2))
    for i, actual in enumerate(tf_dataset):
      actual = self.evaluate(actual)
      # Copy each slice before mutating so TEST_DATA itself is untouched.
      expected = {
          k: v[i].copy() for k, v in TEST_DATA[client].items()
      }
      expected['x'] += 10 * index
      # Same keys, then exact value equality per key.
      self.assertCountEqual(actual, expected)
      for k, v in actual.items():
        self.assertAllEqual(v, expected[k])
def test_dataset_computation_where_client_data_is_ordered_dicts(self):
  """dataset_computation matches spec and data for OrderedDict clients."""
  client_data = from_tensor_slices_client_data.TestClientData(
      TEST_DATA_WITH_ORDEREDDICTS)
  dataset_computation = client_data.dataset_computation
  self.assertIsInstance(dataset_computation, computation_base.Computation)

  # Expected signature: string id -> sequence of {'x','y','z'} elements,
  # where only 'x' has a known (length-2) shape.
  structure = client_data.element_type_structure
  element_type = collections.OrderedDict([
      ('x',
       computation_types.TensorType(structure['x'].dtype, tf.TensorShape(2))),
      ('y', computation_types.TensorType(structure['y'].dtype, None)),
      ('z', computation_types.TensorType(structure['z'].dtype, None)),
  ])
  expected_signature = computation_types.FunctionType(
      computation_types.to_type(tf.string),
      computation_types.SequenceType(element_type))
  self.assertTrue(
      dataset_computation.type_signature.is_equivalent_to(expected_signature))

  # Iterate over each client, invoking the computation and ensuring
  # we received a tf.data.Dataset with the correct data.
  for client_id in TEST_DATA_WITH_ORDEREDDICTS:
    dataset = dataset_computation(client_id)
    self.assertIsInstance(dataset, tf.data.Dataset)
    expected_dataset = tf.data.Dataset.from_tensor_slices(
        TEST_DATA_WITH_ORDEREDDICTS[client_id])
    self.assertSameDatasetsOfDicts(expected_dataset, dataset)
def test_create_tf_dataset_from_all_clients(self):
  """The flattened dataset covers every expanded client exactly once, in order.

  Expected examples are built in the same (client, expansion-index, slice)
  order the flattened dataset is expected to produce them, then popped from
  the front as each actual element is checked.
  """
  client_data = from_tensor_slices_client_data.TestClientData(TEST_DATA)
  num_transformed_clients = 9
  transformed_client_data = transforming_client_data.TransformingClientData(
      client_data, _test_transform_cons, num_transformed_clients)
  # How many pseudo-clients each raw client expands into.
  expansion_factor = num_transformed_clients // len(TEST_DATA)
  tf_dataset = transformed_client_data.create_tf_dataset_from_all_clients()
  self.assertIsInstance(tf_dataset, tf.data.Dataset)
  expected_examples = []
  for expected_data in TEST_DATA.values():
    for index in range(expansion_factor):
      for i in range(len(expected_data['x'])):
        # Copy each slice before mutating so TEST_DATA stays untouched;
        # the transform is expected to add 10 * index to feature 'x'.
        example = {
            k: v[i].copy() for k, v in expected_data.items()
        }
        example['x'] += 10 * index
        expected_examples.append(example)
  for actual in tf_dataset:
    actual = self.evaluate(actual)
    expected = expected_examples.pop(0)
    self.assertCountEqual(actual, expected)
  # Every expected example must have been consumed.
  self.assertEmpty(expected_examples)
), 'CLIENT B': collections.OrderedDict( x=[[10, 11]], y=[7.0], z=['d'], ), 'CLIENT C': collections.OrderedDict( x=[[100, 101], [200, 201]], y=[8.0, 9.0], z=['e', 'f'], ), } TEST_CLIENT_DATA = from_tensor_slices_client_data.TestClientData(TEST_DATA) def _make_transform_expanded(client_id): index_str = tf.strings.split(client_id, sep='_', maxsplit=1)[0] index = tf.cast(tf.strings.to_number(index_str), tf.int32) def fn(data): return collections.OrderedDict([('x', data['x'] + 10 * index), ('y', data['y']), ('z', data['z'])]) return fn def _make_transform_raw(client_id): del client_id
def test_constructor_does_not_modify_in_place(self, tensor_slices_dict):
  """The constructor must leave its input dictionary untouched."""
  snapshot = copy.deepcopy(tensor_slices_dict)
  from_tensor_slices_client_data.TestClientData(tensor_slices_dict)
  self.assertSameStructure(tensor_slices_dict, snapshot)
def test_init_raises_error_if_slices_are_part_list_and_part_dict(self):
  """Mixing list-valued and dict-valued clients is rejected."""
  with self.assertRaises(TypeError):
    from_tensor_slices_client_data.TestClientData(
        TEST_DATA_WITH_PART_LIST_AND_PART_DICT)
def test_init_raises_error_if_slices_are_inconsistent_type(self):
  """Clients whose slices disagree on element type are rejected."""
  with self.assertRaises(TypeError):
    from_tensor_slices_client_data.TestClientData(
        TEST_DATA_WITH_INCONSISTENT_TYPE)
def test_init_raises_error_if_slices_are_namedtuples(self):
  """Namedtuple-valued slices are rejected at construction time."""
  with self.assertRaises(TypeError):
    from_tensor_slices_client_data.TestClientData(
        TEST_DATA_WITH_NAMEDTUPLES)
def test_init_raises_error_if_slices_is_not_dict(self):
  """A non-dict argument is rejected at construction time."""
  with self.assertRaises(TypeError):
    from_tensor_slices_client_data.TestClientData(TEST_DATA_NOT_DICT)
def test_raises_error_if_empty_client_found(self):
  """A client with zero examples is rejected at construction time."""
  empty_client = {'a': []}
  with self.assertRaises(ValueError):
    from_tensor_slices_client_data.TestClientData(empty_client)