def test_client_data_constructs_with_correct_clients_and_types(self):
    """Verifies client ids and element spec of a freshly built ClientData."""
    slices = {'a': [1, 2, 3], 'b': [4, 5]}
    data = from_tensor_slices_client_data.TestClientData(slices)
    self.assertCountEqual(data.client_ids, ['a', 'b'])
    expected_spec = tf.TensorSpec(shape=(), dtype=tf.int32)
    self.assertEqual(data.element_type_structure, expected_spec)
    def test_build_synthethic_iid_client_data(self):
        """Checks synthetic IID resampling yields fresh 5-example datasets."""
        # Deliberately non-IID source data: three clients of constant values.
        source = from_tensor_slices_client_data.TestClientData(
            collections.OrderedDict(a=[1] * 3, b=[2] * 5, c=[3] * 7))

        iid_iter = iter(
            dataset_utils.build_synthethic_iid_datasets(
                source, client_dataset_size=5))

        runs = []
        for _ in range(5):
            # 15 total examples resampled into 3 synthetic clients of 5 each.
            datasets = [[self.evaluate(x) for x in next(iid_iter)]
                        for _ in range(3)]
            self.assertEqual([len(d) for d in datasets], [5, 5, 5])
            runs.append(datasets)

        # Sampling is random, so no two runs should coincide; the collision
        # probability is far below one in a million, so flakes should be
        # imperceptible.
        for i, first in enumerate(runs[:-1]):
            for second in runs[i + 1:]:
                self.assertNotEqual(first, second, msg=str(runs))
# Example #3
  def test_dataset_computation_where_client_data_is_tuples(self):
    """Checks `dataset_computation` round-trips tuple-valued client data.

    NOTE(review): assumes TEST_DATA_WITH_TUPLES maps client-id strings to
    tuples whose elements share one dtype — confirm against the fixture.
    """
    client_data = from_tensor_slices_client_data.TestClientData(
        TEST_DATA_WITH_TUPLES)

    dataset_computation = client_data.dataset_computation
    self.assertIsInstance(dataset_computation, computation_base.Computation)

    # Expected signature: string client id -> sequence of tensors with the
    # tuple elements' dtype and fully unknown shape.
    expected_dataset_comp_type_signature = computation_types.FunctionType(
        computation_types.to_type(tf.string),
        computation_types.SequenceType(
            computation_types.TensorType(
                client_data.element_type_structure[0].dtype,
                tf.TensorShape(None))))

    self.assertTrue(
        dataset_computation.type_signature.is_equivalent_to(
            expected_dataset_comp_type_signature))

    # Iterate over each client, invoking the dataset_computation and ensuring
    # we received a tf.data.Dataset with the correct data.
    for client_id, expected_data in TEST_DATA_WITH_TUPLES.items():
      tf_dataset = dataset_computation(client_id)
      self.assertIsInstance(tf_dataset, tf.data.Dataset)
      self.assertLen(expected_data, tf_dataset.cardinality())
      # Check that everything in tf_dataset is an exact match for the contents
      # of expected_data at the corresponding index.
      for expected, actual in zip(expected_data, tf_dataset):
        self.assertAllEqual(np.asarray(expected), actual.numpy())
    def test_create_tf_dataset_from_all_clients(self):
        """Flattening the transformed data yields range(13), once each."""

        # Expands `CLIENT {N}` into N clients which add range(N) to the feature.
        def expand_client_id(client_id):
            return [
                client_id + '-' + str(i) for i in range(int(client_id[-1]))
            ]

        def make_transform_fn(client_id):
            # The '-<i>' suffix is the pseudo-client index; the returned
            # transform shifts each feature by that index.
            split_client_id = tf.strings.split(client_id, '-')
            index = tf.cast(tf.strings.to_number(split_client_id[1]), tf.int32)
            return lambda x: x + index

        # Recovers the raw client id (the prefix before '-').
        reduce_client_id = lambda client_id: tf.strings.split(client_id,
                                                              sep='-')[0]

        # pyformat: disable
        raw_data = {
            'CLIENT 1': [0],  # expanded to [0]
            'CLIENT 2': [1, 3, 5],  # expanded to [1, 3, 5], [2, 4, 6]
            'CLIENT 3': [7, 10]  # expanded to [7, 10], [8, 11], [9, 12]
        }
        # pyformat: enable
        client_data = from_tensor_slices_client_data.TestClientData(raw_data)
        transformed_client_data = transforming_client_data.TransformingClientData(
            client_data, make_transform_fn, expand_client_id, reduce_client_id)

        flat_data = transformed_client_data.create_tf_dataset_from_all_clients(
        )
        self.assertIsInstance(flat_data, tf.data.Dataset)
        all_features = [batch.numpy() for batch in flat_data]
        # The expansions listed above cover 0..12 with no duplicates.
        self.assertCountEqual(all_features, range(13))
# Example #5
  def test_dataset_computation_raises_error_if_unknown_client_id(self):
    """Invoking the computation with an unseen client id must fail loudly."""
    dataset_computation = (
        from_tensor_slices_client_data.TestClientData(
            TEST_DATA).dataset_computation)
    with self.assertRaises(tf.errors.InvalidArgumentError):
      dataset_computation(CLIENT_ID_NOT_IN_TEST_DATA)
# Example #6
 def test_default_num_transformed_clients(self):
     """Without pseudo-clients, ids map one-to-one onto the raw clients."""
     raw = from_tensor_slices_client_data.TestClientData(TEST_DATA)
     transformed = transforming_client_data.TransformingClientData(
         raw, _test_transform_cons)
     self.assertLen(transformed.client_ids, len(TEST_DATA))
     self.assertFalse(transformed._has_pseudo_clients)
# Example #7
  def test_shuffle_client_ids(self):
    """`create_tf_dataset_from_all_clients` shuffles client order per seed."""
    tensor_slices_dict = {'a': [1, 1], 'b': [2, 2, 2], 'c': [3], 'd': [4, 4]}
    all_examples = [1, 1, 2, 2, 2, 3, 4, 4]
    client_data = from_tensor_slices_client_data.TestClientData(
        tensor_slices_dict)

    def get_flat_dataset(seed):
      # Materializes the flattened dataset as a plain Python list.
      ds = client_data.create_tf_dataset_from_all_clients(seed=seed)
      return [x.numpy() for x in ds]

    d1 = get_flat_dataset(123)
    d2 = get_flat_dataset(456)
    self.assertNotEqual(d1, d2)  # Different random seeds, different order.
    # Both orderings still contain exactly the same multiset of examples.
    self.assertCountEqual(d1, all_examples)
    self.assertCountEqual(d2, all_examples)

    # Test that the default behavior is to use a fresh random seed.
    # We could get unlucky, but we are very unlikely to get unlucky
    # 100 times in a row.
    found_not_equal = False
    for _ in range(100):
      if get_flat_dataset(seed=None) != get_flat_dataset(seed=None):
        found_not_equal = True
        break
    self.assertTrue(found_not_equal)
    def test_serializable_dataset_fn_constructs(self):
        """The serializable dataset fn reproduces a client's examples."""
        client_data = from_tensor_slices_client_data.TestClientData(
            {'a': [1, 2, 3], 'b': [4, 5]})
        dataset = client_data.serializable_dataset_fn('a')
        self.assertEqual([self.evaluate(x) for x in dataset], [1, 2, 3])
# Example #9
 def test_client_ids_property(self):
     """Pseudo-client ids are strings, sorted, and of the requested count."""
     raw = from_tensor_slices_client_data.TestClientData(TEST_DATA)
     transformed = transforming_client_data.TransformingClientData(
         raw, _test_transform_cons, 7)
     ids = transformed.client_ids
     self.assertLen(ids, 7)
     for one_id in ids:
         self.assertIsInstance(one_id, str)
     self.assertListEqual(ids, sorted(ids))
     self.assertTrue(transformed._has_pseudo_clients)
# Example #10
  def test_where_client_data_is_tensors(self):
    """Tensor-valued slices produce per-client datasets of those tensors."""
    client_data = from_tensor_slices_client_data.TestClientData(TEST_DATA)
    self.assertCountEqual(client_data.client_ids, TEST_DATA.keys())
    self.assertEqual(
        tf.TensorSpec(shape=(2,), dtype=tf.int32),
        client_data.element_type_structure)
    for client_id in TEST_DATA:
      expected = tf.data.Dataset.from_tensor_slices(TEST_DATA[client_id])
      actual = client_data.create_tf_dataset_for_client(client_id)
      self.assertSameDatasets(expected, actual)
# Example #11
def get_synthetic():
    """Returns a small synthetic dataset for testing.

    The single client produced has exactly 5 examples. The images and labels
    are derived from a fixed set of hard-coded images.

    Returns:
      A `tff.simulation.datasets.ClientData` object that matches the
      characteristics (other than size) of those provided by
      `tff.simulation.datasets.cifar100.load_data`.
    """
    data = {'synthetic': _get_synthetic_digits_data()}
    return from_tensor_slices_client_data.TestClientData(data)
# Example #12
def get_synthetic():
  """Returns a small synthetic dataset for testing.

  Provides two clients, each holding only 3 examples. The examples are
  derived from a fixed set of examples in the larger dataset, but are not
  exact copies.

  Returns:
    A `tff.simulation.datasets.ClientData` object that matches the
    characteristics (other than size) of those provided by
    `tff.simulation.datasets.stackoverflow.load_data`.
  """
  synthetic_dict = create_synthetic_data_dictionary()
  return from_tensor_slices_client_data.TestClientData(synthetic_dict)
# Example #13
def get_synthetic() -> client_data.ClientData:
    """Creates an in-memory synthetic Shakespeare `ClientData` for tests.

    The returned `tff.simulation.datasets.ClientData` has the same data
    schema as `load_data()`, but uses a very small set of client data loaded
    in-memory. Useful for validation in small tests.

    Returns:
      A `tff.simulation.datasets.ClientData` of synthetic Shakespeare text.
    """
    return from_tensor_slices_client_data.TestClientData(
        _SYNTHETIC_SHAKESPEARE_DATA)
# Example #14
  def test_basic(self):
    """Smoke test: client ids, element spec, and per-client contents."""
    client_data = from_tensor_slices_client_data.TestClientData(
        {'a': [1, 2, 3], 'b': [4, 5]})
    self.assertCountEqual(client_data.client_ids, ['a', 'b'])
    self.assertEqual(client_data.element_type_structure,
                     tf.TensorSpec(shape=(), dtype=tf.int32))

    def materialize(client_id):
      dataset = client_data.create_tf_dataset_for_client(client_id)
      return [self.evaluate(x) for x in dataset]

    self.assertEqual(materialize('a'), [1, 2, 3])
    self.assertEqual(materialize('b'), [4, 5])
# Example #15
def get_synthetic() -> ClientData:
    """Returns a small synthetic dataset for testing.

    The single client produced has 3 examples generated pseudo-randomly.

    Returns:
      A `tff.simulation.datasets.ClientData`.
    """
    stacked_images = tf.stack(
        [tf.random.stateless_normal(shape=(128, 128, 3), seed=(0, i))
         for i in range(3)],
        axis=0)
    data = collections.OrderedDict([
        ('image/decoded', tf.cast(stacked_images, dtype=tf.uint8)),
        ('class', tf.constant([0, 1, 2], dtype=tf.int64)),
    ])
    return from_tensor_slices_client_data.TestClientData({'synthetic': data})
# Example #16
  def test_where_client_data_is_ordered_dicts(self):
    """OrderedDict-valued slices yield a dict-structured element spec."""
    client_data = from_tensor_slices_client_data.TestClientData(
        TEST_DATA_WITH_ORDEREDDICTS)
    self.assertCountEqual(TEST_DATA_WITH_ORDEREDDICTS.keys(),
                          client_data.client_ids)
    # The element structure mirrors the per-key tensor specs of the fixture.
    self.assertEqual(
        collections.OrderedDict([
            ('x', tf.TensorSpec(shape=(2,), dtype=tf.int32)),
            ('y', tf.TensorSpec(shape=(), dtype=tf.float32)),
            ('z', tf.TensorSpec(shape=(), dtype=tf.string))
        ]), client_data.element_type_structure)

    # Each client's dataset must match one built directly from the fixture.
    for client_id in TEST_DATA_WITH_ORDEREDDICTS:
      self.assertSameDatasetsOfDicts(
          tf.data.Dataset.from_tensor_slices(
              TEST_DATA_WITH_ORDEREDDICTS[client_id]),
          client_data.create_tf_dataset_for_client(client_id))
# Example #17
 def test_fail_on_bad_client_id(self):
     """Ids outside the transformed id space must raise ValueError."""
     client_data = from_tensor_slices_client_data.TestClientData(TEST_DATA)
     transformed_client_data = transforming_client_data.TransformingClientData(
         client_data, _test_transform_cons, 7)
     # The following three should be valid.
     transformed_client_data.create_tf_dataset_for_client('CLIENT A_1')
     transformed_client_data.create_tf_dataset_for_client('CLIENT B_1')
     transformed_client_data.create_tf_dataset_for_client('CLIENT A_2')
     # This should not be valid: no corresponding client.
     with self.assertRaisesRegex(
             ValueError,
             'client_id must be a valid string from client_ids.'):
         transformed_client_data.create_tf_dataset_for_client('CLIENT D_0')
     # This should not be valid: index out of range.
     with self.assertRaisesRegex(
             ValueError,
             'client_id must be a valid string from client_ids.'):
         transformed_client_data.create_tf_dataset_for_client('CLIENT B_2')
# Example #18
 def test_create_tf_dataset_for_client(self):
     """Each pseudo-client's dataset equals the raw data shifted by its index.

     NOTE(review): assumes transformed ids end in '_<index>' and that the
     transform adds 10 * index to feature 'x' — confirm against
     `_test_transform_cons`.
     """
     client_data = from_tensor_slices_client_data.TestClientData(TEST_DATA)
     transformed_client_data = transforming_client_data.TransformingClientData(
         client_data, _test_transform_cons, 9)
     for client_id in transformed_client_data.client_ids:
         tf_dataset = transformed_client_data.create_tf_dataset_for_client(
             client_id)
         self.assertIsInstance(tf_dataset, tf.data.Dataset)
         # Split the id into the raw client name and the pseudo-client index.
         pattern = r'^(.*)_(\d*)$'
         match = re.search(pattern, client_id)
         client = match.group(1)
         index = int(match.group(2))
         for i, actual in enumerate(tf_dataset):
             actual = self.evaluate(actual)
             # Copy each value so the += below can't mutate the fixture.
             expected = {
                 k: v[i].copy()
                 for k, v in TEST_DATA[client].items()
             }
             expected['x'] += 10 * index
             self.assertCountEqual(actual, expected)
             for k, v in actual.items():
                 self.assertAllEqual(v, expected[k])
    def test_dataset_computation_where_client_data_is_ordered_dicts(self):
        """`dataset_computation` round-trips OrderedDict-valued client data."""
        client_data = from_tensor_slices_client_data.TestClientData(
            TEST_DATA_WITH_ORDEREDDICTS)

        dataset_computation = client_data.dataset_computation
        self.assertIsInstance(dataset_computation,
                              computation_base.Computation)

        # Expected signature: string client id -> sequence of dicts; 'x'
        # keeps its static length-2 shape while 'y' and 'z' have none.
        expected_dataset_comp_type_signature = computation_types.FunctionType(
            computation_types.to_type(tf.string),
            computation_types.SequenceType(
                collections.OrderedDict([
                    ('x',
                     computation_types.TensorType(
                         client_data.element_type_structure['x'].dtype,
                         tf.TensorShape(2))),
                    ('y',
                     computation_types.TensorType(
                         client_data.element_type_structure['y'].dtype, None)),
                    ('z',
                     computation_types.TensorType(
                         client_data.element_type_structure['z'].dtype, None))
                ])))

        self.assertTrue(
            dataset_computation.type_signature.is_equivalent_to(
                expected_dataset_comp_type_signature))

        # Iterate over each client, invoking the computation and ensuring
        # we received a tf.data.Dataset with the correct data.
        for client_id in TEST_DATA_WITH_ORDEREDDICTS:
            dataset = dataset_computation(client_id)
            self.assertIsInstance(dataset, tf.data.Dataset)

            expected_dataset = tf.data.Dataset.from_tensor_slices(
                TEST_DATA_WITH_ORDEREDDICTS[client_id])
            self.assertSameDatasetsOfDicts(expected_dataset, dataset)
# Example #20
 def test_create_tf_dataset_from_all_clients(self):
     """The flattened dataset enumerates every pseudo-client's shifted copy.

     NOTE(review): assumes examples appear raw-client-major, then by
     pseudo-client index — confirm against TransformingClientData's
     flattening order.
     """
     client_data = from_tensor_slices_client_data.TestClientData(TEST_DATA)
     num_transformed_clients = 9
     transformed_client_data = transforming_client_data.TransformingClientData(
         client_data, _test_transform_cons, num_transformed_clients)
     expansion_factor = num_transformed_clients // len(TEST_DATA)
     tf_dataset = transformed_client_data.create_tf_dataset_from_all_clients(
     )
     self.assertIsInstance(tf_dataset, tf.data.Dataset)
     expected_examples = []
     for expected_data in TEST_DATA.values():
         for index in range(expansion_factor):
             for i in range(len(expected_data['x'])):
                 # Copy each value so the += below can't mutate the fixture.
                 example = {
                     k: v[i].copy()
                     for k, v in expected_data.items()
                 }
                 example['x'] += 10 * index
                 expected_examples.append(example)
     for actual in tf_dataset:
         actual = self.evaluate(actual)
         expected = expected_examples.pop(0)
         self.assertCountEqual(actual, expected)
     self.assertEmpty(expected_examples)
    ),
    'CLIENT B':
    collections.OrderedDict(
        x=[[10, 11]],
        y=[7.0],
        z=['d'],
    ),
    'CLIENT C':
    collections.OrderedDict(
        x=[[100, 101], [200, 201]],
        y=[8.0, 9.0],
        z=['e', 'f'],
    ),
}

# Shared fixture: an in-memory ClientData built from the TEST_DATA slices.
TEST_CLIENT_DATA = from_tensor_slices_client_data.TestClientData(TEST_DATA)


def _make_transform_expanded(client_id):
    """Builds a transform that shifts 'x' by 10 * the id's numeric component.

    NOTE(review): this parses the piece *before* the first '_' as the index
    (`maxsplit=1`, element 0), while other code in this file matches the
    index as a trailing '_<digits>' suffix — confirm the expanded-id format
    is '<index>_<raw id>' here.
    """
    index_str = tf.strings.split(client_id, sep='_', maxsplit=1)[0]
    index = tf.cast(tf.strings.to_number(index_str), tf.int32)

    def fn(data):
        # 'y' and 'z' pass through unchanged; only 'x' is shifted.
        return collections.OrderedDict([('x', data['x'] + 10 * index),
                                        ('y', data['y']), ('z', data['z'])])

    return fn


def _make_transform_raw(client_id):
    del client_id
 def test_constructor_does_not_modify_in_place(self, tensor_slices_dict):
     """Constructing a TestClientData must leave the input dict untouched."""
     snapshot = copy.deepcopy(tensor_slices_dict)
     from_tensor_slices_client_data.TestClientData(tensor_slices_dict)
     self.assertSameStructure(tensor_slices_dict, snapshot)
# Example #23
 def test_init_raises_error_if_slices_are_part_list_and_part_dict(self):
     """Mixing list- and dict-valued client slices must raise TypeError."""
     with self.assertRaises(TypeError):
         from_tensor_slices_client_data.TestClientData(
             TEST_DATA_WITH_PART_LIST_AND_PART_DICT)
# Example #24
 def test_init_raises_error_if_slices_are_inconsistent_type(self):
     """Clients whose slices disagree on type must raise TypeError."""
     with self.assertRaises(TypeError):
         from_tensor_slices_client_data.TestClientData(
             TEST_DATA_WITH_INCONSISTENT_TYPE)
# Example #25
 def test_init_raises_error_if_slices_are_namedtuples(self):
     """Namedtuple-valued slices are unsupported and must raise TypeError."""
     with self.assertRaises(TypeError):
         from_tensor_slices_client_data.TestClientData(
             TEST_DATA_WITH_NAMEDTUPLES)
# Example #26
 def test_init_raises_error_if_slices_is_not_dict(self):
     """A non-dict tensor-slices argument must raise TypeError."""
     with self.assertRaises(TypeError):
         from_tensor_slices_client_data.TestClientData(TEST_DATA_NOT_DICT)
# Example #27
 def test_raises_error_if_empty_client_found(self):
     """A client with zero examples is rejected at construction time."""
     with self.assertRaises(ValueError):
         from_tensor_slices_client_data.TestClientData({'a': []})