def test_create_tf_dataset_from_all_clients(self):
  client_data = hdf5_client_data.HDF5ClientData(
      TransformingClientDataTest.test_data_filepath)
  num_transformed_clients = 9
  transformed_client_data = transforming_client_data.TransformingClientData(
      client_data, _test_transform, num_transformed_clients)
  expansion_factor = num_transformed_clients // len(TEST_DATA)
  tf_dataset = transformed_client_data.create_tf_dataset_from_all_clients()
  self.assertIsInstance(tf_dataset, tf.data.Dataset)

  expected_examples = []
  for expected_data in six.itervalues(TEST_DATA):
    for index in range(expansion_factor):
      for i in range(len(expected_data['x'])):
        # Copy the values so the in-place `+=` below does not mutate the
        # shared TEST_DATA arrays across iterations.
        example = {k: v[i].copy() for k, v in six.iteritems(expected_data)}
        example['x'] += 10 * index
        expected_examples.append(example)

  for actual in tf_dataset:
    expected = expected_examples.pop(0)
    actual = tf.contrib.framework.nest.map_structure(lambda t: t.numpy(),
                                                     actual)
    self.assertCountEqual(actual, expected)
  self.assertEmpty(expected_examples)
def infinite_emnist(emnist_client_data, num_pseudo_clients):
  """Converts a Federated EMNIST dataset into an Infinite Federated EMNIST set.

  Infinite Federated EMNIST expands each writer from the EMNIST dataset into
  some number of pseudo-clients, each of which has the same characters as the
  original writer but with a fixed random affine transformation applied. The
  distribution over affine transformations is approximately equivalent to the
  one described at https://www.cs.toronto.edu/~tijmen/affNIST/. It applies the
  following transformations in this order:

  1. A random rotation chosen uniformly between -20 and 20 degrees.
  2. A random shearing adding between -0.2 and 0.2 of the x coordinate to the
     y coordinate (after centering).
  3. A random scaling between 0.8 and 1.25 (sampled log uniformly).
  4. A random translation between -5 and 5 pixels in both the x and y axes.

  Args:
    emnist_client_data: The `tff.simulation.ClientData` to convert.
    num_pseudo_clients: How many pseudo-clients to generate for each real
      client. Each pseudo-client is formed by applying a given random affine
      transformation to the characters written by a given real user. The
      first pseudo-client for a given user applies the identity
      transformation, so the original users are always included.

  Returns:
    An expanded `tff.simulation.ClientData`.
  """
  num_client_ids = len(emnist_client_data.client_ids)
  return transforming_client_data.TransformingClientData(
      raw_client_data=emnist_client_data,
      make_transform_fn=_make_transform_fn,
      num_transformed_clients=(num_client_ids * num_pseudo_clients))
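# The sketch below illustrates how the affine distribution described in the
# docstring above could be sampled and applied. It is an illustration only,
# not the actual `_make_transform_fn` used by TFF: the helper names
# (`sample_affine_matrix`, `transform_image`), the seeding scheme, and the
# use of NumPy/SciPy are all assumptions for this example.
import numpy as np
import scipy.ndimage


def sample_affine_matrix(rng):
  """Samples the 2x2 linear part of the transform described above."""
  # 1. Rotation uniform in [-20, 20] degrees.
  theta = np.radians(rng.uniform(-20.0, 20.0))
  rotation = np.array([[np.cos(theta), -np.sin(theta)],
                       [np.sin(theta), np.cos(theta)]])
  # 2. Shear: add between -0.2 and 0.2 of the x coordinate to the y
  # coordinate.
  shear = np.array([[1.0, 0.0],
                    [rng.uniform(-0.2, 0.2), 1.0]])
  # 3. Scale between 0.8 and 1.25, sampled log-uniformly.
  scale = np.exp(rng.uniform(np.log(0.8), np.log(1.25)))
  # Compose in the documented order: rotate, then shear, then scale.
  return scale * shear @ rotation


def transform_image(image, rng):
  """Applies one sampled affine warp to a 2-D image around its center."""
  linear = sample_affine_matrix(rng)
  # 4. Translation uniform in [-5, 5] pixels on each axis.
  translation = rng.uniform(-5.0, 5.0, size=2)
  center = (np.array(image.shape) - 1) / 2.0
  # scipy.ndimage.affine_transform maps *output* coordinates back to *input*
  # coordinates, so invert the forward transform and fold the centering and
  # translation into the offset term.
  inverse = np.linalg.inv(linear)
  offset = center - inverse @ (center + translation)
  return scipy.ndimage.affine_transform(image, inverse, offset=offset)


# Example (hypothetical seeding scheme): warp one 28x28 character for the
# pseudo-client with index 3. Index 0 would keep the identity transform.
# warped = transform_image(char, np.random.default_rng(3))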
def test_create_tf_dataset_for_client(self):
  client_data = hdf5_client_data.HDF5ClientData(
      TransformingClientDataTest.test_data_filepath)
  transformed_client_data = transforming_client_data.TransformingClientData(
      client_data, _test_transform, 9)
  for client_id in transformed_client_data.client_ids:
    tf_dataset = transformed_client_data.create_tf_dataset_for_client(
        client_id)
    self.assertIsInstance(tf_dataset, tf.data.Dataset)
    pattern = r'^(.*)_(\d*)$'
    match = re.search(pattern, client_id)
    client = match.group(1)
    index = int(match.group(2))
    for i, actual in enumerate(tf_dataset):
      expected = {k: v[i] for k, v in six.iteritems(TEST_DATA[client])}
      expected['x'] = expected['x'] + 10 * index
      self.assertCountEqual(actual, expected)
      for k, v in six.iteritems(actual):
        self.assertAllEqual(v.numpy(), expected[k])
def test_create_tf_dataset_from_all_clients(self):
  client_data = hdf5_client_data.HDF5ClientData(
      TransformingClientDataTest.test_data_filepath)
  num_transformed_clients = 9
  transformed_client_data = transforming_client_data.TransformingClientData(
      client_data, _test_transform_cons, num_transformed_clients)
  expansion_factor = num_transformed_clients // len(TEST_DATA)
  tf_dataset = transformed_client_data.create_tf_dataset_from_all_clients()
  self.assertIsInstance(tf_dataset, tf.data.Dataset)

  expected_examples = []
  for expected_data in TEST_DATA.values():
    for index in range(expansion_factor):
      for i in range(len(expected_data['x'])):
        example = {k: v[i].copy() for k, v in expected_data.items()}
        example['x'] += 10 * index
        expected_examples.append(example)

  for actual in tf_dataset:
    actual = self.evaluate(actual)
    expected = expected_examples.pop(0)
    self.assertCountEqual(actual, expected)
  self.assertEmpty(expected_examples)
def test_default_num_transformed_clients(self):
  client_data = from_tensor_slices_client_data.FromTensorSlicesClientData(
      TEST_DATA)
  transformed_client_data = transforming_client_data.TransformingClientData(
      client_data, _test_transform_cons)
  client_ids = transformed_client_data.client_ids
  self.assertLen(client_ids, len(TEST_DATA))
  self.assertFalse(transformed_client_data._has_pseudo_clients)
def test_client_ids_property(self):
  client_data = hdf5_client_data.HDF5ClientData(
      TransformingClientDataTest.test_data_filepath)
  expansion_factor = 2.5
  transformed_client_data = transforming_client_data.TransformingClientData(
      client_data, lambda: 0, expansion_factor)
  self.assertLen(transformed_client_data.client_ids,
                 int(len(TEST_DATA) * expansion_factor))
def test_client_ids_property(self):
  client_data = from_tensor_slices_client_data.FromTensorSlicesClientData(
      TEST_DATA)
  num_transformed_clients = 7
  transformed_client_data = transforming_client_data.TransformingClientData(
      client_data, _test_transform_cons, num_transformed_clients)
  client_ids = transformed_client_data.client_ids
  self.assertLen(client_ids, num_transformed_clients)
  for client_id in client_ids:
    self.assertIsInstance(client_id, str)
  self.assertListEqual(client_ids, sorted(client_ids))
  self.assertTrue(transformed_client_data._has_pseudo_clients)
def test_deprecation_warning(self):
  client_data = from_tensor_slices_client_data.FromTensorSlicesClientData(
      TEST_DATA)
  with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter('always')
    transforming_client_data.TransformingClientData(client_data,
                                                    _test_transform_cons)
    self.assertNotEmpty(w)
    self.assertEqual(w[0].category, DeprecationWarning)
    self.assertRegex(
        str(w[0].message),
        'tff.simulation.TransformingClientData is deprecated')
def test_client_ids_property(self):
  client_data = from_tensor_slices_client_data.FromTensorSlicesClientData(
      TEST_DATA)
  num_transformed_clients = 7
  transformed_client_data = transforming_client_data.TransformingClientData(
      client_data, _test_transform_cons, num_transformed_clients)
  client_ids = transformed_client_data.client_ids
  # Check length of client_ids.
  self.assertLen(client_ids, num_transformed_clients)
  # Check that they are all strings.
  for client_id in client_ids:
    self.assertIsInstance(client_id, str)
  # Check ids are sorted.
  self.assertListEqual(client_ids, sorted(client_ids))
def test_fail_on_bad_client_id(self):
  client_data = from_tensor_slices_client_data.FromTensorSlicesClientData(
      TEST_DATA)
  transformed_client_data = transforming_client_data.TransformingClientData(
      client_data, _test_transform_cons, 7)

  # The following three should be valid.
  transformed_client_data.create_tf_dataset_for_client('CLIENT A_1')
  transformed_client_data.create_tf_dataset_for_client('CLIENT B_1')
  transformed_client_data.create_tf_dataset_for_client('CLIENT A_2')

  # This should not be valid: no corresponding client.
  with self.assertRaisesRegex(
      ValueError, 'client_id must be a valid string from client_ids.'):
    transformed_client_data.create_tf_dataset_for_client('CLIENT D_0')

  # This should not be valid: index out of range.
  with self.assertRaisesRegex(
      ValueError, 'client_id must be a valid string from client_ids.'):
    transformed_client_data.create_tf_dataset_for_client('CLIENT B_2')
def test_client_ids_property(self):
  client_data = hdf5_client_data.HDF5ClientData(
      TransformingClientDataTest.test_data_filepath)
  num_transformed_clients = 7
  transformed_client_data = transforming_client_data.TransformingClientData(
      client_data, _test_transform, num_transformed_clients)
  client_ids = transformed_client_data.client_ids
  # Check length of client_ids.
  self.assertLen(client_ids, num_transformed_clients)
  # Check that they are all strings.
  for client_id in client_ids:
    self.assertIsInstance(client_id, str)
  # Check ids are sorted.
  self.assertListEqual(client_ids, sorted(client_ids))
def test_create_tf_dataset_for_client(self):
  client_data = from_tensor_slices_client_data.FromTensorSlicesClientData(
      TEST_DATA)
  transformed_client_data = transforming_client_data.TransformingClientData(
      client_data, _test_transform_cons, 9)
  for client_id in transformed_client_data.client_ids:
    tf_dataset = transformed_client_data.create_tf_dataset_for_client(
        client_id)
    self.assertIsInstance(tf_dataset, tf.data.Dataset)
    pattern = r'^(.*)_(\d*)$'
    match = re.search(pattern, client_id)
    client = match.group(1)
    index = int(match.group(2))
    for i, actual in enumerate(tf_dataset):
      actual = self.evaluate(actual)
      expected = {k: v[i].copy() for k, v in TEST_DATA[client].items()}
      expected['x'] += 10 * index
      self.assertCountEqual(actual, expected)
      for k, v in actual.items():
        self.assertAllEqual(v, expected[k])
def test_fail_on_bad_client_id(self):
  client_data = hdf5_client_data.HDF5ClientData(
      TransformingClientDataTest.test_data_filepath)
  transformed_client_data = transforming_client_data.TransformingClientData(
      client_data, _test_transform, 7)

  # The following three should be valid.
  transformed_client_data.create_tf_dataset_for_client('CLIENT A_1')
  transformed_client_data.create_tf_dataset_for_client('CLIENT B_1')
  transformed_client_data.create_tf_dataset_for_client('CLIENT A_2')

  # This should not be valid: no corresponding client.
  with self.assertRaisesRegex(
      ValueError, 'client_id must be a valid string from client_ids.'):
    transformed_client_data.create_tf_dataset_for_client('CLIENT D_0')

  # This should not be valid: index out of range.
  with self.assertRaisesRegex(
      ValueError, 'client_id must be a valid string from client_ids.'):
    transformed_client_data.create_tf_dataset_for_client('CLIENT B_2')