def test_shuffle_client_ids(self):
  """Shuffling is seed-dependent but always preserves the example multiset."""
  slices = {'a': [1, 1], 'b': [2, 2, 2], 'c': [3], 'd': [4, 4]}
  expected_pool = [1, 1, 2, 2, 2, 3, 4, 4]
  client_data = from_tensor_slices_client_data.FromTensorSlicesClientData(
      slices)

  def flatten(seed):
    # Materialize the all-clients dataset into a plain Python list.
    dataset = client_data.create_tf_dataset_from_all_clients(seed=seed)
    return [value.numpy() for value in dataset]

  first = flatten(123)
  second = flatten(456)
  # Different random seeds, different order.
  self.assertNotEqual(first, second)
  self.assertCountEqual(first, expected_pool)
  self.assertCountEqual(second, expected_pool)

  # Test that the default behavior is to use a fresh random seed. We could
  # get unlucky, but we are very unlikely to get unlucky 100 times in a row.
  self.assertTrue(
      any(flatten(seed=None) != flatten(seed=None) for _ in range(100)))
def test_dataset_computation_where_client_data_is_tuples(self):
  """Checks `dataset_computation` type signature and output for tuple data."""
  client_data = from_tensor_slices_client_data.FromTensorSlicesClientData(
      TEST_DATA_WITH_TUPLES)
  dataset_computation = client_data.dataset_computation
  self.assertIsInstance(dataset_computation, computation_base.Computation)
  # The computation maps a string client id to a sequence whose element
  # dtype matches the first entry of the tuple structure; the shape is left
  # unconstrained (TensorShape(None)).
  expected_dataset_comp_type_signature = computation_types.FunctionType(
      computation_types.to_type(tf.string),
      computation_types.SequenceType(
          computation_types.TensorType(
              client_data.element_type_structure[0].dtype,
              tf.TensorShape(None))))
  self.assertTrue(
      dataset_computation.type_signature.is_equivalent_to(
          expected_dataset_comp_type_signature))
  # Iterate over each client, invoking the dataset_computation and ensuring
  # we received a tf.data.Dataset with the correct data.
  for client_id, expected_data in TEST_DATA_WITH_TUPLES.items():
    tf_dataset = dataset_computation(client_id)
    self.assertIsInstance(tf_dataset, tf.data.Dataset)
    self.assertLen(expected_data, tf_dataset.cardinality())
    # Check that everything in tf_dataset is an exact match for the contents
    # of expected_data at the corresponding index.
    for expected, actual in zip(expected_data, tf_dataset):
      self.assertAllEqual(np.asarray(expected), actual.numpy())
def test_create_tf_dataset_from_all_clients(self):
  """Flattened dataset contains exactly the expanded per-client examples.

  Bug fix: the original final check was `assertCountEqual(actual, expected)`
  on two dicts, which only compares their *keys* — the example values were
  never verified. We now encode each example as a canonical tuple of
  (key, value-bytes) pairs and compare the full multisets, which is both
  value-sensitive and independent of the dataset's iteration order.
  """
  client_data = from_tensor_slices_client_data.FromTensorSlicesClientData(
      TEST_DATA)
  num_transformed_clients = 9
  transformed_client_data = transforming_client_data.TransformingClientData(
      client_data, _test_transform_cons, num_transformed_clients)
  expansion_factor = num_transformed_clients // len(TEST_DATA)
  tf_dataset = transformed_client_data.create_tf_dataset_from_all_clients()
  self.assertIsInstance(tf_dataset, tf.data.Dataset)

  # Build the expected examples: each base example appears once per
  # pseudo-client index, with 'x' shifted by 10 * index.
  expected_examples = []
  for expected_data in TEST_DATA.values():
    for index in range(expansion_factor):
      for i in range(len(expected_data['x'])):
        example = {k: v[i].copy() for k, v in expected_data.items()}
        example['x'] += 10 * index
        expected_examples.append(example)

  def canonical(example):
    # Hashable, value-sensitive encoding of a {key: array-like} example.
    return tuple(
        sorted((k, np.asarray(v).tobytes()) for k, v in example.items()))

  actual_examples = [self.evaluate(x) for x in tf_dataset]
  self.assertCountEqual([canonical(e) for e in actual_examples],
                        [canonical(e) for e in expected_examples])
def test_dataset_computation_raises_error_if_unknown_client_id(self):
  """Invoking the dataset computation with an unseen id must fail."""
  data = from_tensor_slices_client_data.FromTensorSlicesClientData(TEST_DATA)
  computation = data.dataset_computation
  with self.assertRaises(tf.errors.InvalidArgumentError):
    computation(CLIENT_ID_NOT_IN_TEST_DATA)
def test_default_num_transformed_clients(self):
  """Without an explicit count there is one transformed client per raw one."""
  base = from_tensor_slices_client_data.FromTensorSlicesClientData(TEST_DATA)
  transformed = transforming_client_data.TransformingClientData(
      base, _test_transform_cons)
  ids = transformed.client_ids
  self.assertLen(ids, len(TEST_DATA))
  # No expansion requested, so no pseudo clients should be created.
  self.assertFalse(transformed._has_pseudo_clients)
def test_deprecation_warning_raised_on_init(self):
  """Constructing the client data emits the expected DeprecationWarning."""
  slices = {'a': [1, 2, 3], 'b': [4, 5]}
  with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    from_tensor_slices_client_data.FromTensorSlicesClientData(slices)
  self.assertNotEmpty(caught)
  self.assertEqual(caught[0].category, DeprecationWarning)
  self.assertRegex(
      str(caught[0].message),
      'tff.simulation.FromTensorSlicesClientData is deprecated')
def test_client_ids_property(self):
  """Expanded client ids are strings, sorted, and of the requested length."""
  base = from_tensor_slices_client_data.FromTensorSlicesClientData(TEST_DATA)
  requested = 7
  transformed = transforming_client_data.TransformingClientData(
      base, _test_transform_cons, requested)
  ids = transformed.client_ids
  self.assertLen(ids, requested)
  for one_id in ids:
    self.assertIsInstance(one_id, str)
  self.assertListEqual(ids, sorted(ids))
  # Expanding beyond the raw client count must create pseudo clients.
  self.assertTrue(transformed._has_pseudo_clients)
def test_where_client_data_is_tuples(self):
  """Tuple-valued slices round-trip through per-client datasets."""
  data = from_tensor_slices_client_data.FromTensorSlicesClientData(
      TEST_DATA_WITH_TUPLES)
  self.assertCountEqual(TEST_DATA_WITH_TUPLES.keys(), data.client_ids)
  expected_structure = (tf.TensorSpec(shape=(), dtype=tf.int32),
                        tf.TensorSpec(shape=(), dtype=tf.int32))
  self.assertEqual(data.element_type_structure, expected_structure)
  for client_id in TEST_DATA_WITH_TUPLES:
    reference = tf.data.Dataset.from_tensor_slices(
        TEST_DATA_WITH_TUPLES[client_id])
    self.assertSameDatasets(reference,
                            data.create_tf_dataset_for_client(client_id))
def get_synthetic() -> client_data.ClientData:
  """Creates `tff.simulation.ClientData` for a synthetic in-memory example of Shakespeare.

  The returned `tff.simulation.ClientData` will have the same data schema as
  `load_data()`, but uses a very small set of client data loaded in-memory.

  This synthetic data is useful for validation in small tests.

  Returns:
    A `tff.simulation.ClientData` of synthetic Shakespeare text.
  """
  # Wraps the hard-coded in-memory slices in the standard ClientData API.
  return from_tensor_slices_client_data.FromTensorSlicesClientData(
      _SYNTHETIC_SHAKESPEARE_DATA)
def test_client_ids_property(self):
  """The transformed data exposes exactly 7 sorted, string client ids."""
  raw = from_tensor_slices_client_data.FromTensorSlicesClientData(TEST_DATA)
  total = 7
  transformed = transforming_client_data.TransformingClientData(
      raw, _test_transform_cons, total)
  ids = transformed.client_ids
  # Check length of client_ids.
  self.assertLen(ids, 7)
  # Check that they are all strings.
  for cid in ids:
    self.assertIsInstance(cid, str)
  # Check ids are sorted.
  self.assertListEqual(ids, sorted(ids))
def get_synthetic():
  """Returns a small synthetic dataset for testing.

  Provides two clients, each client with only 3 examples. The examples are
  derived from a fixed set of examples in the larger dataset, but are not
  exact copies.

  Returns:
    A `tff.simulation.ClientData` object that matches the characteristics
    (other than size) of those provided by
    `tff.simulation.datasets.stackoverflow.load_data`.
  """
  # Wraps the hard-coded in-memory slices in the standard ClientData API.
  return from_tensor_slices_client_data.FromTensorSlicesClientData(
      _SYNTHETIC_STACKOVERFLOW_DATA)
def test_basic(self):
  """Two-client integer slices expose ids, output types/shapes, and data."""
  slices = {'a': [1, 2, 3], 'b': [4, 5]}
  data = from_tensor_slices_client_data.FromTensorSlicesClientData(slices)
  self.assertCountEqual(data.client_ids, ['a', 'b'])
  self.assertEqual(data.output_types, tf.int32)
  self.assertEqual(data.output_shapes, ())

  def materialize(dataset):
    # Pull a per-client dataset into a plain Python list.
    return [element.numpy() for element in dataset]

  self.assertEqual(materialize(data.create_tf_dataset_for_client('a')),
                   [1, 2, 3])
  self.assertEqual(materialize(data.create_tf_dataset_for_client('b')),
                   [4, 5])
def test_basic(self):
  """Scalar int slices yield the right element structure and client data."""
  slices = {'a': [1, 2, 3], 'b': [4, 5]}
  data = from_tensor_slices_client_data.FromTensorSlicesClientData(slices)
  self.assertCountEqual(data.client_ids, ['a', 'b'])
  self.assertEqual(data.element_type_structure,
                   tf.TensorSpec(shape=(), dtype=tf.int32))

  def materialize(dataset):
    # Evaluate every element of a per-client dataset into a Python list.
    return [self.evaluate(element) for element in dataset]

  self.assertEqual(materialize(data.create_tf_dataset_for_client('a')),
                   [1, 2, 3])
  self.assertEqual(materialize(data.create_tf_dataset_for_client('b')),
                   [4, 5])
def test_fail_on_bad_client_id(self):
  """Valid pseudo-client ids succeed; unknown or out-of-range ids raise."""
  base = from_tensor_slices_client_data.FromTensorSlicesClientData(TEST_DATA)
  transformed = transforming_client_data.TransformingClientData(
      base, _test_transform_cons, 7)

  # The following three should be valid.
  for good_id in ('CLIENT A_1', 'CLIENT B_1', 'CLIENT A_2'):
    transformed.create_tf_dataset_for_client(good_id)

  # This should not be valid: no corresponding client.
  with self.assertRaisesRegex(
      ValueError, 'client_id must be a valid string from client_ids.'):
    transformed.create_tf_dataset_for_client('CLIENT D_0')

  # This should not be valid: index out of range.
  with self.assertRaisesRegex(
      ValueError, 'client_id must be a valid string from client_ids.'):
    transformed.create_tf_dataset_for_client('CLIENT B_2')
def test_where_client_data_is_ordered_dicts(self):
  """OrderedDict-valued slices preserve element specs and per-client data."""
  client_data = from_tensor_slices_client_data.FromTensorSlicesClientData(
      TEST_DATA_WITH_ORDEREDDICTS)
  # Every key of the source dict should appear as a client id.
  self.assertCountEqual(TEST_DATA_WITH_ORDEREDDICTS.keys(),
                        client_data.client_ids)
  # The element structure mirrors the dict layout: 'x' is a length-2 int
  # vector, 'y' a scalar float, 'z' a scalar string.
  self.assertEqual(
      collections.OrderedDict([
          ('x', tf.TensorSpec(shape=(2,), dtype=tf.int32)),
          ('y', tf.TensorSpec(shape=(), dtype=tf.float32)),
          ('z', tf.TensorSpec(shape=(), dtype=tf.string))
      ]), client_data.element_type_structure)
  # Each per-client dataset must match one built directly from the slices.
  for client_id in TEST_DATA_WITH_ORDEREDDICTS:
    self.assertSameDatasetsOfDicts(
        tf.data.Dataset.from_tensor_slices(
            TEST_DATA_WITH_ORDEREDDICTS[client_id]),
        client_data.create_tf_dataset_for_client(client_id))
def test_create_tf_dataset_for_client(self):
  """Each pseudo-client dataset equals the base data shifted by its index."""
  client_data = from_tensor_slices_client_data.FromTensorSlicesClientData(
      TEST_DATA)
  transformed_client_data = transforming_client_data.TransformingClientData(
      client_data, _test_transform_cons, 9)
  for client_id in transformed_client_data.client_ids:
    tf_dataset = transformed_client_data.create_tf_dataset_for_client(
        client_id)
    self.assertIsInstance(tf_dataset, tf.data.Dataset)
    # Pseudo-client ids look like '<base id>_<index>'; recover both parts.
    pattern = r'^(.*)_(\d*)$'
    match = re.search(pattern, client_id)
    client = match.group(1)
    index = int(match.group(2))
    for i, actual in enumerate(tf_dataset):
      actual = self.evaluate(actual)
      expected = {k: v[i].copy() for k, v in TEST_DATA[client].items()}
      # Apparently _test_transform_cons shifts 'x' by 10 per pseudo-client
      # index — confirm against the transform's definition.
      expected['x'] += 10 * index
      # assertCountEqual on two dicts only compares their keys; the per-key
      # loop below is what checks the actual values.
      self.assertCountEqual(actual, expected)
      for k, v in actual.items():
        self.assertAllEqual(v, expected[k])
def get_synthetic(num_clients=2):
  """Quickly returns a small synthetic dataset, useful for unit tests, etc.

  Each client produced has exactly 10 examples, one of each digit. The images
  are derived from a fixed set of hard-coded images, and transformed using
  `tff.simulation.datasets.emnist.infinite_emnist` to produce the desired
  number of clients.

  Args:
    num_clients: The number of synthetic clients to generate.

  Returns:
    A `tff.simulation.ClientData` object that matches the characteristics
    (other than size) of those provided by
    `tff.simulation.datasets.emnist.load_data`.
  """
  return get_infinite(
      # Base ClientData with one client
      from_tensor_slices_client_data.FromTensorSlicesClientData(
          {'synthetic': _get_synthetic_digits_data()}),
      num_pseudo_clients=num_clients)
def test_dataset_computation_where_client_data_is_ordered_dicts(self):
  """Checks `dataset_computation` signature and output for OrderedDict data."""
  client_data = from_tensor_slices_client_data.FromTensorSlicesClientData(
      TEST_DATA_WITH_ORDEREDDICTS)
  dataset_computation = client_data.dataset_computation
  self.assertIsInstance(dataset_computation, computation_base.Computation)
  # Expected signature: string client id -> sequence of OrderedDicts whose
  # dtypes come from the client data's element structure; 'x' has a known
  # length-2 shape while 'y' and 'z' are left unconstrained (shape None).
  expected_dataset_comp_type_signature = computation_types.FunctionType(
      computation_types.to_type(tf.string),
      computation_types.SequenceType(
          collections.OrderedDict([
              ('x',
               computation_types.TensorType(
                   client_data.element_type_structure['x'].dtype,
                   tf.TensorShape(2))),
              ('y',
               computation_types.TensorType(
                   client_data.element_type_structure['y'].dtype, None)),
              ('z',
               computation_types.TensorType(
                   client_data.element_type_structure['z'].dtype, None))
          ])))
  self.assertTrue(
      dataset_computation.type_signature.is_equivalent_to(
          expected_dataset_comp_type_signature))
  # Iterate over each client, invoking the computation and ensuring
  # we received a tf.data.Dataset with the correct data.
  for client_id in TEST_DATA_WITH_ORDEREDDICTS:
    dataset = dataset_computation(client_id)
    self.assertIsInstance(dataset, tf.data.Dataset)
    expected_dataset = tf.data.Dataset.from_tensor_slices(
        TEST_DATA_WITH_ORDEREDDICTS[client_id])
    self.assertSameDatasetsOfDicts(expected_dataset, dataset)
def test_init_raises_error_if_slices_is_not_dict(self):
  """A non-dict tensor-slices argument is rejected at construction."""
  make_client_data = from_tensor_slices_client_data.FromTensorSlicesClientData
  with self.assertRaises(TypeError):
    make_client_data(TEST_DATA_NOT_DICT)
def test_init_raises_error_if_slices_are_namedtuples(self):
  """Namedtuple-valued slices are rejected at construction."""
  make_client_data = from_tensor_slices_client_data.FromTensorSlicesClientData
  with self.assertRaises(TypeError):
    make_client_data(TEST_DATA_WITH_NAMEDTUPLES)
def test_empty(self):
  """A client with an empty example list is rejected at construction."""
  make_client_data = from_tensor_slices_client_data.FromTensorSlicesClientData
  with self.assertRaises(ValueError):
    make_client_data({'a': []})
def test_init_raises_error_if_slices_are_inconsistent_type(self):
  """Slices whose clients disagree on element type are rejected."""
  make_client_data = from_tensor_slices_client_data.FromTensorSlicesClientData
  with self.assertRaises(TypeError):
    make_client_data(TEST_DATA_WITH_INCONSISTENT_TYPE)
def test_init_raises_error_if_slices_are_part_list_and_part_dict(self):
  """Mixed list/dict-valued slices are rejected at construction."""
  make_client_data = from_tensor_slices_client_data.FromTensorSlicesClientData
  with self.assertRaises(TypeError):
    make_client_data(TEST_DATA_WITH_PART_LIST_AND_PART_DICT)