コード例 #1
0
    def test_create_tf_dataset_from_all_clients(self):
        """Checks the concatenated dataset over every pseudo-client.

        The expected stream is built by walking each raw client in `TEST_DATA`
        and, for each of its `expansion_factor` pseudo-clients, shifting `x` by
        `10 * index` (presumably the offset `_test_transform` applies — confirm
        against its definition). Both the keys and the values of every yielded
        example are compared, and the example counts must match exactly.
        """
        client_data = hdf5_client_data.HDF5ClientData(
            TransformingClientDataTest.test_data_filepath)

        num_transformed_clients = 9
        transformed_client_data = transforming_client_data.TransformingClientData(
            client_data, _test_transform, num_transformed_clients)
        expansion_factor = num_transformed_clients // len(TEST_DATA)

        tf_dataset = transformed_client_data.create_tf_dataset_from_all_clients(
        )
        self.assertIsInstance(tf_dataset, tf.data.Dataset)

        expected_examples = []
        for expected_data in six.itervalues(TEST_DATA):
            for index in range(expansion_factor):
                for i in range(len(expected_data['x'])):
                    example = {
                        k: v[i]
                        for k, v in six.iteritems(expected_data)
                    }
                    # Rebind instead of `+=`: if v[i] is a NumPy array view, an
                    # in-place add would mutate TEST_DATA itself, and every
                    # expected example for row i would share that mutated array.
                    example['x'] = example['x'] + 10 * index
                    expected_examples.append(example)

        num_actual = 0
        for actual in tf_dataset:
            self.assertLess(num_actual, len(expected_examples),
                            'Dataset yielded more examples than expected.')
            expected = expected_examples[num_actual]
            num_actual += 1
            actual = tf.contrib.framework.nest.map_structure(
                lambda t: t.numpy(), actual)
            # assertCountEqual on two dicts only compares their keys, so the
            # values are checked separately.
            self.assertCountEqual(actual, expected)
            for k, v in six.iteritems(actual):
                self.assertAllEqual(v, expected[k])
        self.assertEqual(num_actual, len(expected_examples))
コード例 #2
0
ファイル: load_data.py プロジェクト: xiaowei0202/federated
def infinite_emnist(emnist_client_data, num_pseudo_clients):
    """Expands Federated EMNIST into an "Infinite Federated EMNIST" dataset.

    Each writer in the source dataset is split into `num_pseudo_clients`
    pseudo-clients. Every pseudo-client sees the same characters as the
    original writer, with one fixed, randomly drawn affine transformation
    applied to them. The distribution over transformations approximately
    matches the one described at https://www.cs.toronto.edu/~tijmen/affNIST/.
    The transformations are applied in this order:

      1. A rotation drawn uniformly from -20 to 20 degrees.
      2. A shear that adds between -0.2 and 0.2 of the x coordinate to the
         y coordinate (after centering).
      3. A scaling between 0.8 and 1.25, sampled log-uniformly.
      4. A translation between -5 and 5 pixels along both the x and y axes.

    Args:
      emnist_client_data: The `tff.simulation.ClientData` to convert.
      num_pseudo_clients: Number of pseudo-clients to create per real client.
        Each pseudo-client applies its own fixed random affine transformation
        to one real user's characters. The first pseudo-client of every user
        applies the identity transformation, so the original users are always
        included.

    Returns:
      A `tff.simulation.ClientData` built with `num_transformed_clients` set to
      `len(emnist_client_data.client_ids) * num_pseudo_clients`.
    """
    real_client_count = len(emnist_client_data.client_ids)
    total_pseudo_clients = real_client_count * num_pseudo_clients
    return transforming_client_data.TransformingClientData(
        raw_client_data=emnist_client_data,
        make_transform_fn=_make_transform_fn,
        num_transformed_clients=total_pseudo_clients)
コード例 #3
0
    def test_create_tf_dataset_for_client(self):
        """Checks each pseudo-client's dataset against its raw client's data.

        Pseudo-client ids are parsed as `<raw client id>_<index>`. Every
        yielded example's `x` must equal the raw value shifted by
        `10 * index`, the other fields must match the raw data, and the dataset
        must contain exactly as many examples as the raw client has.
        """
        client_data = hdf5_client_data.HDF5ClientData(
            TransformingClientDataTest.test_data_filepath)

        transformed_client_data = transforming_client_data.TransformingClientData(
            client_data, _test_transform, 9)

        # Compiled once outside the loop. `\d+` (not `\d*`) makes an id with an
        # empty index fail the match, instead of crashing in int('').
        pattern = re.compile(r'^(.*)_(\d+)$')
        for client_id in transformed_client_data.client_ids:
            tf_dataset = transformed_client_data.create_tf_dataset_for_client(
                client_id)
            self.assertIsInstance(tf_dataset, tf.data.Dataset)

            match = pattern.search(client_id)
            self.assertIsNotNone(match, 'Unexpected client id: %s' % client_id)
            client = match.group(1)
            index = int(match.group(2))
            num_examples = 0
            for i, actual in enumerate(tf_dataset):
                expected = {
                    k: v[i]
                    for k, v in six.iteritems(TEST_DATA[client])
                }
                expected['x'] = expected['x'] + 10 * index
                self.assertCountEqual(actual, expected)
                for k, v in six.iteritems(actual):
                    self.assertAllEqual(v.numpy(), expected[k])
                num_examples += 1
            # Without this check, a truncated or empty dataset would pass.
            self.assertEqual(num_examples, len(TEST_DATA[client]['x']))
コード例 #4
0
    def test_create_tf_dataset_from_all_clients(self):
        """Checks the concatenated dataset over every pseudo-client.

        The expected stream is built by walking each raw client in `TEST_DATA`
        and, for each of its `expansion_factor` pseudo-clients, shifting `x` by
        `10 * index` (presumably the offset `_test_transform_cons` applies —
        confirm against its definition). Keys and values of every yielded
        example are compared, and the example counts must match exactly.
        """
        client_data = hdf5_client_data.HDF5ClientData(
            TransformingClientDataTest.test_data_filepath)

        num_transformed_clients = 9
        transformed_client_data = transforming_client_data.TransformingClientData(
            client_data, _test_transform_cons, num_transformed_clients)
        expansion_factor = num_transformed_clients // len(TEST_DATA)

        tf_dataset = transformed_client_data.create_tf_dataset_from_all_clients(
        )
        self.assertIsInstance(tf_dataset, tf.data.Dataset)

        expected_examples = []
        for expected_data in TEST_DATA.values():
            for index in range(expansion_factor):
                for i in range(len(expected_data['x'])):
                    # .copy() keeps the in-place `+=` below from mutating
                    # TEST_DATA.
                    example = {
                        k: v[i].copy()
                        for k, v in expected_data.items()
                    }
                    example['x'] += 10 * index
                    expected_examples.append(example)

        num_actual = 0
        for actual in tf_dataset:
            self.assertLess(num_actual, len(expected_examples),
                            'Dataset yielded more examples than expected.')
            actual = self.evaluate(actual)
            expected = expected_examples[num_actual]
            num_actual += 1
            # assertCountEqual on two dicts only compares their keys, so the
            # values are checked separately.
            self.assertCountEqual(actual, expected)
            for k, v in actual.items():
                self.assertAllEqual(v, expected[k])
        self.assertEqual(num_actual, len(expected_examples))
コード例 #5
0
 def test_default_num_transformed_clients(self):
   """Omitting num_transformed_clients keeps one client per raw client.

   The resulting instance should expose exactly `len(TEST_DATA)` ids and
   report that it has no pseudo-clients.
   """
   raw_data = from_tensor_slices_client_data.FromTensorSlicesClientData(
       TEST_DATA)
   transformed = transforming_client_data.TransformingClientData(
       raw_data, _test_transform_cons)
   ids = transformed.client_ids
   self.assertLen(ids, len(TEST_DATA))
   self.assertFalse(transformed._has_pseudo_clients)
コード例 #6
0
 def test_client_ids_property(self):
     """Checks the client-id count for a fractional expansion factor.

     The expected count is `int(len(TEST_DATA) * factor)`, so a fractional
     product is truncated toward zero.
     """
     raw_data = hdf5_client_data.HDF5ClientData(
         TransformingClientDataTest.test_data_filepath)
     factor = 2.5
     transformed = transforming_client_data.TransformingClientData(
         raw_data, lambda: 0, factor)
     ids = transformed.client_ids
     self.assertLen(ids, int(len(TEST_DATA) * factor))
コード例 #7
0
 def test_client_ids_property(self):
   """Checks the ids exposed when pseudo-clients are requested explicitly.

   Requesting `requested` transformed clients should yield exactly that many
   string ids, in sorted order, on an instance that reports it has
   pseudo-clients.
   """
   requested = 7
   raw_data = from_tensor_slices_client_data.FromTensorSlicesClientData(
       TEST_DATA)
   transformed = transforming_client_data.TransformingClientData(
       raw_data, _test_transform_cons, requested)
   ids = transformed.client_ids
   self.assertLen(ids, requested)
   for an_id in ids:
     self.assertIsInstance(an_id, str)
   self.assertListEqual(ids, sorted(ids))
   self.assertTrue(transformed._has_pseudo_clients)
コード例 #8
0
    def test_client_ids_property(self):
        """Checks that constructing TransformingClientData emits a deprecation.

        NOTE(review): despite its name, this test does not inspect
        `client_ids`. It only verifies that a DeprecationWarning mentioning
        `tff.simulation.TransformingClientData` is raised on construction. The
        name is kept so existing test selection still finds it.
        """
        client_data = from_tensor_slices_client_data.FromTensorSlicesClientData(
            TEST_DATA)

        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter('always')
            transforming_client_data.TransformingClientData(
                client_data, _test_transform_cons)

        # Inspect every recorded DeprecationWarning instead of assuming it is
        # w[0]: any unrelated warning recorded first would otherwise make this
        # test fail spuriously.
        deprecation_messages = [
            str(warning.message)
            for warning in w
            if issubclass(warning.category, DeprecationWarning)
        ]
        self.assertNotEmpty(deprecation_messages)
        self.assertRegex(
            '\n'.join(deprecation_messages),
            'tff.simulation.TransformingClientData is deprecated')
コード例 #9
0
 def test_client_ids_property(self):
     """Checks that client_ids has the requested size, is str-typed and sorted."""
     client_data = from_tensor_slices_client_data.FromTensorSlicesClientData(
         TEST_DATA)
     num_transformed_clients = 7
     transformed_client_data = transforming_client_data.TransformingClientData(
         client_data, _test_transform_cons, num_transformed_clients)
     client_ids = transformed_client_data.client_ids
     # Check length of client_ids against the requested count (not a second,
     # independently maintained literal).
     self.assertLen(client_ids, num_transformed_clients)
     # Check that they are all strings.
     for client_id in client_ids:
         self.assertIsInstance(client_id, str)
     # Check ids are sorted.
     self.assertListEqual(client_ids, sorted(client_ids))
コード例 #10
0
 def test_fail_on_bad_client_id(self):
   """Checks which pseudo-client ids are accepted and which are rejected.

   With 7 transformed clients, the ids in `accepted_ids` must resolve. An id
   naming an unknown raw client ('CLIENT D_0') or an out-of-range index
   ('CLIENT B_2') must raise ValueError.
   """
   raw_data = from_tensor_slices_client_data.FromTensorSlicesClientData(
       TEST_DATA)
   transformed = transforming_client_data.TransformingClientData(
       raw_data, _test_transform_cons, 7)
   accepted_ids = ('CLIENT A_1', 'CLIENT B_1', 'CLIENT A_2')
   for an_id in accepted_ids:
     transformed.create_tf_dataset_for_client(an_id)
   error_pattern = 'client_id must be a valid string from client_ids.'
   # 'CLIENT D_0': no corresponding raw client. 'CLIENT B_2': index out of range.
   for an_id in ('CLIENT D_0', 'CLIENT B_2'):
     with self.assertRaisesRegex(ValueError, error_pattern):
       transformed.create_tf_dataset_for_client(an_id)
コード例 #11
0
    def test_client_ids_property(self):
        """Checks that client_ids has the requested size, is str-typed and sorted."""
        client_data = hdf5_client_data.HDF5ClientData(
            TransformingClientDataTest.test_data_filepath)
        num_transformed_clients = 7
        transformed_client_data = transforming_client_data.TransformingClientData(
            client_data, _test_transform, num_transformed_clients)
        client_ids = transformed_client_data.client_ids

        # Check length of client_ids against the requested count (not a second,
        # independently maintained literal).
        self.assertLen(client_ids, num_transformed_clients)

        # Check that they are all strings.
        for client_id in client_ids:
            self.assertIsInstance(client_id, str)

        # Check ids are sorted.
        self.assertListEqual(client_ids, sorted(client_ids))
コード例 #12
0
 def test_create_tf_dataset_for_client(self):
   """Checks each pseudo-client's dataset against its raw client's data.

   Pseudo-client ids are parsed as `<raw client id>_<index>`. Every yielded
   example's `x` must equal the raw value shifted by `10 * index`, the other
   fields must match the raw data, and the dataset must contain exactly as many
   examples as the raw client has.
   """
   client_data = from_tensor_slices_client_data.FromTensorSlicesClientData(
       TEST_DATA)
   transformed_client_data = transforming_client_data.TransformingClientData(
       client_data, _test_transform_cons, 9)
   # Compiled once outside the loop. `\d+` (not `\d*`) makes an id with an
   # empty index fail the match, instead of crashing in int('').
   pattern = re.compile(r'^(.*)_(\d+)$')
   for client_id in transformed_client_data.client_ids:
     tf_dataset = transformed_client_data.create_tf_dataset_for_client(
         client_id)
     self.assertIsInstance(tf_dataset, tf.data.Dataset)
     match = pattern.search(client_id)
     self.assertIsNotNone(match, 'Unexpected client id: %s' % client_id)
     client = match.group(1)
     index = int(match.group(2))
     num_examples = 0
     for i, actual in enumerate(tf_dataset):
       actual = self.evaluate(actual)
       # .copy() keeps the in-place `+=` below from mutating TEST_DATA.
       expected = {k: v[i].copy() for k, v in TEST_DATA[client].items()}
       expected['x'] += 10 * index
       self.assertCountEqual(actual, expected)
       for k, v in actual.items():
         self.assertAllEqual(v, expected[k])
       num_examples += 1
     # Without this check, a truncated or empty dataset would pass.
     self.assertEqual(num_examples, len(TEST_DATA[client]['x']))
コード例 #13
0
    def test_fail_on_bad_client_id(self):
        """Checks which pseudo-client ids are accepted and which are rejected.

        With 7 transformed clients built from the HDF5 test data, the ids in
        `accepted_ids` must resolve. An id naming an unknown raw client
        ('CLIENT D_0') or an out-of-range index ('CLIENT B_2') must raise
        ValueError.
        """
        raw_data = hdf5_client_data.HDF5ClientData(
            TransformingClientDataTest.test_data_filepath)
        transformed = transforming_client_data.TransformingClientData(
            raw_data, _test_transform, 7)

        accepted_ids = ('CLIENT A_1', 'CLIENT B_1', 'CLIENT A_2')
        for an_id in accepted_ids:
            transformed.create_tf_dataset_for_client(an_id)

        error_pattern = 'client_id must be a valid string from client_ids.'
        # 'CLIENT D_0': no corresponding raw client.
        # 'CLIENT B_2': index out of range.
        for an_id in ('CLIENT D_0', 'CLIENT B_2'):
            with self.assertRaisesRegex(ValueError, error_pattern):
                transformed.create_tf_dataset_for_client(an_id)