def test_create_tf_dataset_from_all_clients(self):
        """Flattening all expanded clients yields every transformed feature.

        Each raw client `CLIENT {N}` is expanded into N pseudo-clients
        `CLIENT {N}-0` ... `CLIENT {N}-{N-1}`; pseudo-client `i` adds `i` to
        each of its features, so the union of all features is `range(13)`.
        """

        def expand_client_id(client_id):
            # Parse the whole trailing integer, not just the last character,
            # so that `CLIENT {N}` is also handled correctly for N >= 10.
            num_expanded = int(client_id.rsplit(' ', 1)[-1])
            return [client_id + '-' + str(i) for i in range(num_expanded)]

        def make_transform_fn(client_id):
            # `client_id` is an expanded id such as `CLIENT 2-1`; the suffix
            # after '-' is the offset added to every feature of that client.
            split_client_id = tf.strings.split(client_id, '-')
            index = tf.cast(tf.strings.to_number(split_client_id[1]), tf.int32)
            return lambda x: x + index

        def reduce_client_id(client_id):
            # Maps an expanded id back to its raw id, e.g. `CLIENT 2-1` to
            # `CLIENT 2`.
            return tf.strings.split(client_id, sep='-')[0]

        # pyformat: disable
        raw_data = {
            'CLIENT 1': [0],  # expanded to [0]
            'CLIENT 2': [1, 3, 5],  # expanded to [1, 3, 5], [2, 4, 6]
            'CLIENT 3': [7, 10]  # expanded to [7, 10], [8, 11], [9, 12]
        }
        # pyformat: enable
        client_data = from_tensor_slices_client_data.TestClientData(raw_data)
        transformed_client_data = transforming_client_data.TransformingClientData(
            client_data, make_transform_fn, expand_client_id, reduce_client_id)

        flat_data = (
            transformed_client_data.create_tf_dataset_from_all_clients())
        self.assertIsInstance(flat_data, tf.data.Dataset)
        # Each dataset element is a single (scalar) feature, not a batch.
        all_features = [example.numpy() for example in flat_data]
        self.assertCountEqual(all_features, range(13))
示例#2
0
def get_infinite(emnist_client_data, num_pseudo_clients):
  """Builds an Infinite Federated EMNIST dataset from Federated EMNIST.

  Every EMNIST writer is turned into `num_pseudo_clients` pseudo-clients. A
  pseudo-client sees the same characters as its writer, but with a fixed
  random affine transformation applied. The transformation distribution
  roughly follows the one described at
  https://www.cs.toronto.edu/~tijmen/affNIST/ and composes, in this order:

    1. A rotation drawn uniformly between -20 and 20 degrees.
    2. A shear adding between -0.2 and 0.2 of the x coordinate to the y
       coordinate (after centering).
    3. A scaling between 0.8 and 1.25, sampled log-uniformly.
    4. A translation between -5 and 5 pixels along both the x and y axes.

  Args:
    emnist_client_data: The `tff.simulation.datasets.ClientData` to expand.
    num_pseudo_clients: Number of pseudo-clients generated per real client.
      Each one applies its own random affine transformation to that client's
      characters; the first pseudo-client of every user uses the identity
      transformation, so the original users are always present.

  Returns:
    An expanded `tff.simulation.datasets.ClientData`, built by requesting
    `num_pseudo_clients` times as many clients as `emnist_client_data` has.
  """
  # The per-client transformation comes from `_make_transform_fn`; this
  # function only sizes the expanded client population.
  total_clients = len(emnist_client_data.client_ids) * num_pseudo_clients
  return transforming_client_data.TransformingClientData(
      raw_client_data=emnist_client_data,
      make_transform_fn=_make_transform_fn,
      num_transformed_clients=total_clients)
 def test_create_tf_dataset_from_all_clients(self):
     """Flattening all pseudo-clients yields one example per expected row.

     The test expects `num_transformed_clients // len(TEST_DATA)`
     pseudo-clients per raw client, with pseudo-client `index` shifting the
     `'x'` feature by `10 * index`.
     """
     client_data = from_tensor_slices_client_data.FromTensorSlicesClientData(
         TEST_DATA)
     num_transformed_clients = 9
     transformed_client_data = transforming_client_data.TransformingClientData(
         client_data, _test_transform_cons, num_transformed_clients)
     expansion_factor = num_transformed_clients // len(TEST_DATA)
     tf_dataset = transformed_client_data.create_tf_dataset_from_all_clients()
     self.assertIsInstance(tf_dataset, tf.data.Dataset)
     expected_examples = []
     for expected_data in TEST_DATA.values():
         for index in range(expansion_factor):
             for i in range(len(expected_data['x'])):
                 example = {k: v[i].copy() for k, v in expected_data.items()}
                 example['x'] += 10 * index
                 expected_examples.append(example)
     actual_examples = [self.evaluate(actual) for actual in tf_dataset]
     # One length check reports both missing and extra examples with a clear
     # message, instead of an IndexError when the dataset yields extras.
     self.assertLen(actual_examples, len(expected_examples))
     for actual, expected in zip(actual_examples, expected_examples):
         # NOTE(review): assertCountEqual on two dicts compares only their
         # keys, so feature values are not verified here. The iteration order
         # of create_tf_dataset_from_all_clients is not established in this
         # file; confirm it before adding order-dependent value checks.
         self.assertCountEqual(actual, expected)
示例#4
0
 def test_default_num_transformed_clients(self):
     """Omitting `num_transformed_clients` exposes only the raw clients."""
     transformed = transforming_client_data.TransformingClientData(
         from_tensor_slices_client_data.TestClientData(TEST_DATA),
         _test_transform_cons)
     ids = transformed.client_ids
     self.assertLen(ids, len(TEST_DATA))
     self.assertFalse(transformed._has_pseudo_clients)
示例#5
0
 def test_client_ids_property(self):
     """Requesting 7 transformed clients yields 7 sorted string ids."""
     requested = 7
     transformed = transforming_client_data.TransformingClientData(
         from_tensor_slices_client_data.TestClientData(TEST_DATA),
         _test_transform_cons, requested)
     ids = transformed.client_ids
     self.assertLen(ids, requested)
     for cid in ids:
         self.assertIsInstance(cid, str)
     self.assertListEqual(ids, sorted(ids))
     # More clients than raw entries means pseudo-clients were created.
     self.assertTrue(transformed._has_pseudo_clients)
示例#6
0
 def test_fail_on_bad_client_id(self):
     """Only ids naming an existing raw client and an in-range index work."""
     transformed = transforming_client_data.TransformingClientData(
         from_tensor_slices_client_data.TestClientData(TEST_DATA),
         _test_transform_cons, 7)
     # These ids are expected to be accepted.
     for good_id in ('CLIENT A_1', 'CLIENT B_1', 'CLIENT A_2'):
         transformed.create_tf_dataset_for_client(good_id)
     # 'CLIENT D_0' has no corresponding raw client; 'CLIENT B_2' has an
     # index that is out of range.
     for bad_id in ('CLIENT D_0', 'CLIENT B_2'):
         with self.assertRaisesRegex(
                 ValueError,
                 'client_id must be a valid string from client_ids.'):
             transformed.create_tf_dataset_for_client(bad_id)
示例#7
0
 def test_create_tf_dataset_for_client(self):
     """Each pseudo-client yields its raw client's examples, shifted in 'x'.

     Client ids are expected to look like `<raw client id>_<index>`. The test
     expects pseudo-client `index` to add `10 * index` to feature `'x'`
     without changing the number of examples.
     """
     client_data = from_tensor_slices_client_data.TestClientData(TEST_DATA)
     transformed_client_data = transforming_client_data.TransformingClientData(
         client_data, _test_transform_cons, 9)
     pattern = r'^(.*)_(\d+)$'
     for client_id in transformed_client_data.client_ids:
         tf_dataset = transformed_client_data.create_tf_dataset_for_client(
             client_id)
         self.assertIsInstance(tf_dataset, tf.data.Dataset)
         match = re.search(pattern, client_id)
         # Fail with a readable message rather than an AttributeError on None.
         self.assertIsNotNone(match, 'Unexpected client id: %s' % client_id)
         client = match.group(1)
         index = int(match.group(2))
         raw_data = TEST_DATA[client]
         actual_examples = [self.evaluate(actual) for actual in tf_dataset]
         # Without this, a dataset that drops examples would pass silently
         # (and an extra example would surface only as an IndexError).
         self.assertLen(actual_examples, len(raw_data['x']))
         for i, actual in enumerate(actual_examples):
             expected = {k: v[i].copy() for k, v in raw_data.items()}
             expected['x'] += 10 * index
             # assertCountEqual on dicts checks keys; values are checked below.
             self.assertCountEqual(actual, expected)
             for k, v in actual.items():
                 self.assertAllEqual(v, expected[k])
 def test_default_num_transformed_clients(self):
     """Without expansion arguments, client ids are the raw TEST_DATA keys."""
     ids = transforming_client_data.TransformingClientData(
         TEST_CLIENT_DATA, _make_transform_raw).client_ids
     self.assertCountEqual(ids, TEST_DATA.keys())
    return fn


# Number of pseudo-client ids that `test_expand_client_id` produces for each
# raw client id.
NUM_EXPANDED_CLIENTS = 3


def test_expand_client_id(client_id):
    """Returns `'<i>_<client_id>'` for each i in range(NUM_EXPANDED_CLIENTS)."""
    expanded = []
    for i in range(NUM_EXPANDED_CLIENTS):
        expanded.append('_'.join([str(i), client_id]))
    return expanded


def test_reduce_client_id(client_id):
    """Recovers the raw client id from an expanded `'<i>_<raw id>'` id.

    Only the first '_' is split on, so a raw id that itself contains '_' is
    returned intact. The prefix added by `test_expand_client_id` is
    `str(i)` for a non-negative integer, which never contains '_'.
    """
    return tf.strings.split(client_id, sep='_', maxsplit=1)[1]


# Shared fixture: TEST_CLIENT_DATA with every raw client id expanded by
# `test_expand_client_id` and mapped back by `test_reduce_client_id`.
TRANSFORMED_CLIENT_DATA = transforming_client_data.TransformingClientData(
    TEST_CLIENT_DATA,
    _make_transform_expanded,
    test_expand_client_id,
    test_reduce_client_id,
)


class TransformingClientDataTest(tf.test.TestCase):
    def test_client_ids_property(self):
        """Each raw client contributes NUM_EXPANDED_CLIENTS sorted str ids."""
        expected_count = NUM_EXPANDED_CLIENTS * len(TEST_DATA)
        ids = TRANSFORMED_CLIENT_DATA.client_ids
        self.assertLen(ids, expected_count)
        for cid in ids:
            self.assertIsInstance(cid, str)
        self.assertListEqual(ids, sorted(ids))

    def test_default_num_transformed_clients(self):
        transformed_client_data = transforming_client_data.TransformingClientData(
            TEST_CLIENT_DATA, _make_transform_raw)