def test_shuffle_client_ids(self):
    """Distinct seeds reorder the flattened data; seed=None is fresh per call."""
    slices = {'a': [1, 1], 'b': [2, 2, 2], 'c': [3], 'd': [4, 4]}
    expected_examples = [1, 1, 2, 2, 2, 3, 4, 4]
    data = from_tensor_slices_client_data.FromTensorSlicesClientData(slices)

    def flatten(seed):
        ds = data.create_tf_dataset_from_all_clients(seed=seed)
        return [value.numpy() for value in ds]

    first = flatten(123)
    second = flatten(456)
    # Different random seeds should produce different orders of the same
    # multiset of examples.
    self.assertNotEqual(first, second)
    self.assertCountEqual(first, expected_examples)
    self.assertCountEqual(second, expected_examples)

    # Default behavior draws a fresh random seed each call. We could get
    # unlucky, but 100 identical shuffles in a row is astronomically unlikely.
    # `any` short-circuits on the first observed difference.
    self.assertTrue(
        any(flatten(seed=None) != flatten(seed=None) for _ in range(100)))
  def test_dataset_computation_where_client_data_is_tuples(self):
    """dataset_computation maps a string id to a sequence of scalar tensors.

    Verifies the computation's type signature (string -> sequence of
    unknown-shape tensors of the slices' dtype) and that invoking it per
    client yields a `tf.data.Dataset` matching `TEST_DATA_WITH_TUPLES`.
    """
    client_data = from_tensor_slices_client_data.FromTensorSlicesClientData(
        TEST_DATA_WITH_TUPLES)

    dataset_computation = client_data.dataset_computation
    self.assertIsInstance(dataset_computation, computation_base.Computation)

    # Expected signature: (string -> Sequence[Tensor]) where the element dtype
    # comes from the first component of the tuple elements and the shape is
    # left unknown (TensorShape(None)).
    expected_dataset_comp_type_signature = computation_types.FunctionType(
        computation_types.to_type(tf.string),
        computation_types.SequenceType(
            computation_types.TensorType(
                client_data.element_type_structure[0].dtype,
                tf.TensorShape(None))))

    self.assertTrue(
        dataset_computation.type_signature.is_equivalent_to(
            expected_dataset_comp_type_signature))

    # Iterate over each client, invoking the dataset_computation and ensuring
    # we received a tf.data.Dataset with the correct data.
    for client_id, expected_data in TEST_DATA_WITH_TUPLES.items():
      tf_dataset = dataset_computation(client_id)
      self.assertIsInstance(tf_dataset, tf.data.Dataset)
      self.assertLen(expected_data, tf_dataset.cardinality())
      # Check that everything in tf_dataset is an exact match for the contents
      # of expected_data at the corresponding index.
      for expected, actual in zip(expected_data, tf_dataset):
        self.assertAllEqual(np.asarray(expected), actual.numpy())
 def test_create_tf_dataset_from_all_clients(self):
     """Flattened dataset covers every pseudo-client's transformed examples."""
     base = from_tensor_slices_client_data.FromTensorSlicesClientData(TEST_DATA)
     num_transformed_clients = 9
     transformed = transforming_client_data.TransformingClientData(
         base, _test_transform_cons, num_transformed_clients)
     expansion_factor = num_transformed_clients // len(TEST_DATA)
     dataset = transformed.create_tf_dataset_from_all_clients()
     self.assertIsInstance(dataset, tf.data.Dataset)

     # Build the expected stream: each raw client's rows are repeated once per
     # expansion index, with 'x' shifted by 10 per index (the test transform).
     expected_examples = []
     for client_values in TEST_DATA.values():
         num_rows = len(client_values['x'])
         for index in range(expansion_factor):
             for row in range(num_rows):
                 example = {k: v[row].copy() for k, v in client_values.items()}
                 example['x'] += 10 * index
                 expected_examples.append(example)

     # Consume the dataset in order, matching against the expected queue.
     for actual in dataset:
         self.assertCountEqual(self.evaluate(actual), expected_examples.pop(0))
     self.assertEmpty(expected_examples)
  def test_dataset_computation_raises_error_if_unknown_client_id(self):
    """Invoking dataset_computation with an unknown id raises a TF error."""
    client_data = from_tensor_slices_client_data.FromTensorSlicesClientData(
        TEST_DATA)
    # An id absent from the client data surfaces an InvalidArgumentError
    # rather than silently producing an empty dataset.
    with self.assertRaises(tf.errors.InvalidArgumentError):
      client_data.dataset_computation(CLIENT_ID_NOT_IN_TEST_DATA)
Example #5
0
 def test_default_num_transformed_clients(self):
   """Omitting num_transformed_clients yields one client per raw client."""
   base = from_tensor_slices_client_data.FromTensorSlicesClientData(TEST_DATA)
   transformed = transforming_client_data.TransformingClientData(
       base, _test_transform_cons)
   # No expansion requested: client count matches the raw data and no
   # pseudo-clients are created.
   self.assertLen(transformed.client_ids, len(TEST_DATA))
   self.assertFalse(transformed._has_pseudo_clients)
 def test_deprecation_warning_raised_on_init(self):
     """Constructing FromTensorSlicesClientData emits a DeprecationWarning."""
     slices = {'a': [1, 2, 3], 'b': [4, 5]}
     with warnings.catch_warnings(record=True) as caught:
         # Record every warning, even ones already seen in this process.
         warnings.simplefilter('always')
         from_tensor_slices_client_data.FromTensorSlicesClientData(slices)
         self.assertNotEmpty(caught)
         self.assertEqual(caught[0].category, DeprecationWarning)
         self.assertRegex(
             str(caught[0].message),
             'tff.simulation.FromTensorSlicesClientData is deprecated')
Example #7
0
 def test_client_ids_property(self):
   """client_ids is a sorted list of strings, one per pseudo-client."""
   base = from_tensor_slices_client_data.FromTensorSlicesClientData(TEST_DATA)
   num_transformed_clients = 7
   transformed = transforming_client_data.TransformingClientData(
       base, _test_transform_cons, num_transformed_clients)
   ids = transformed.client_ids
   self.assertLen(ids, num_transformed_clients)
   # Every id is a string, and the property returns them pre-sorted.
   for cid in ids:
     self.assertIsInstance(cid, str)
   self.assertListEqual(ids, sorted(ids))
   # 7 exceeds the raw client count, so pseudo-clients must exist.
   self.assertTrue(transformed._has_pseudo_clients)
  def test_where_client_data_is_tuples(self):
    """Tuple-valued slices round-trip through FromTensorSlicesClientData."""
    client_data = from_tensor_slices_client_data.FromTensorSlicesClientData(
        TEST_DATA_WITH_TUPLES)
    self.assertCountEqual(TEST_DATA_WITH_TUPLES.keys(), client_data.client_ids)

    # Each element is a pair of scalar int32 tensors.
    expected_spec = (tf.TensorSpec(shape=(), dtype=tf.int32),
                     tf.TensorSpec(shape=(), dtype=tf.int32))
    self.assertEqual(client_data.element_type_structure, expected_spec)

    # Per-client datasets match a direct from_tensor_slices construction.
    for client_id in TEST_DATA_WITH_TUPLES:
      self.assertSameDatasets(
          tf.data.Dataset.from_tensor_slices(TEST_DATA_WITH_TUPLES[client_id]),
          client_data.create_tf_dataset_for_client(client_id))
Example #9
0
def get_synthetic() -> client_data.ClientData:
    """Creates `tff.simulation.ClientData` for a synthetic in-memory example of Shakespeare.

    The returned `tff.simulation.ClientData` will have the same data schema as
    `load_data()`, but uses a very small set of client data loaded in-memory.

    This synthetic data is useful for validation in small tests.

    Returns:
      A `tff.simulation.ClientData` of synthetic Shakespeare text.
    """
    # Wrap the hard-coded per-client text dict in an in-memory ClientData.
    return from_tensor_slices_client_data.FromTensorSlicesClientData(
        _SYNTHETIC_SHAKESPEARE_DATA)
 def test_client_ids_property(self):
     """TransformingClientData exposes sorted string client ids."""
     base = from_tensor_slices_client_data.FromTensorSlicesClientData(
         TEST_DATA)
     transformed = transforming_client_data.TransformingClientData(
         base, _test_transform_cons, 7)
     ids = transformed.client_ids
     # Exactly one id per requested pseudo-client.
     self.assertLen(ids, 7)
     # All ids are strings.
     for cid in ids:
         self.assertIsInstance(cid, str)
     # The property returns them already sorted.
     self.assertListEqual(ids, sorted(ids))
Example #11
0
def get_synthetic():
    """Returns a small synthetic dataset for testing.

  Provides two clients, each client with only 3 examples. The examples are
  derived from a fixed set of examples in the larger dataset, but are not exact
  copies.

  Returns:
     A `tff.simulation.ClientData` object that matches the characteristics
     (other than size) of those provided by
     `tff.simulation.datasets.stackoverflow.load_data`.
  """
    # Wrap the hard-coded per-client examples in an in-memory ClientData.
    return from_tensor_slices_client_data.FromTensorSlicesClientData(
        _SYNTHETIC_STACKOVERFLOW_DATA)
Example #12
0
    def test_basic(self):
        """Client ids, output types/shapes, and per-client data round-trip."""
        slices = {'a': [1, 2, 3], 'b': [4, 5]}
        data = from_tensor_slices_client_data.FromTensorSlicesClientData(slices)
        self.assertCountEqual(data.client_ids, ['a', 'b'])
        self.assertEqual(data.output_types, tf.int32)
        self.assertEqual(data.output_shapes, ())

        def materialize(dataset):
            # Pull eager tensors out of the dataset as plain numpy values.
            return [element.numpy() for element in dataset]

        self.assertEqual(
            materialize(data.create_tf_dataset_for_client('a')), [1, 2, 3])
        self.assertEqual(
            materialize(data.create_tf_dataset_for_client('b')), [4, 5])
    def test_basic(self):
        """element_type_structure and per-client contents match the slices."""
        slices = {'a': [1, 2, 3], 'b': [4, 5]}
        data = from_tensor_slices_client_data.FromTensorSlicesClientData(slices)
        self.assertCountEqual(data.client_ids, ['a', 'b'])
        # Elements are scalar int32 tensors.
        self.assertEqual(data.element_type_structure,
                         tf.TensorSpec(shape=(), dtype=tf.int32))

        def materialize(dataset):
            return [self.evaluate(element) for element in dataset]

        self.assertEqual(
            materialize(data.create_tf_dataset_for_client('a')), [1, 2, 3])
        self.assertEqual(
            materialize(data.create_tf_dataset_for_client('b')), [4, 5])
Example #14
0
 def test_fail_on_bad_client_id(self):
   """Only '<raw id>_<in-range index>' pseudo-client ids are accepted."""
   base = from_tensor_slices_client_data.FromTensorSlicesClientData(TEST_DATA)
   transformed = transforming_client_data.TransformingClientData(
       base, _test_transform_cons, 7)
   # These name existing raw clients with in-range expansion indices.
   for good_id in ('CLIENT A_1', 'CLIENT B_1', 'CLIENT A_2'):
     transformed.create_tf_dataset_for_client(good_id)
   # Invalid: no corresponding raw client.
   with self.assertRaisesRegex(
       ValueError, 'client_id must be a valid string from client_ids.'):
     transformed.create_tf_dataset_for_client('CLIENT D_0')
   # Invalid: expansion index out of range for this raw client.
   with self.assertRaisesRegex(
       ValueError, 'client_id must be a valid string from client_ids.'):
     transformed.create_tf_dataset_for_client('CLIENT B_2')
  def test_where_client_data_is_ordered_dicts(self):
    """OrderedDict-valued slices preserve keys, tensor specs, and contents."""
    client_data = from_tensor_slices_client_data.FromTensorSlicesClientData(
        TEST_DATA_WITH_ORDEREDDICTS)
    self.assertCountEqual(TEST_DATA_WITH_ORDEREDDICTS.keys(),
                          client_data.client_ids)
    # Element structure mirrors the per-key specs of the input slices.
    expected_structure = collections.OrderedDict([
        ('x', tf.TensorSpec(shape=(2,), dtype=tf.int32)),
        ('y', tf.TensorSpec(shape=(), dtype=tf.float32)),
        ('z', tf.TensorSpec(shape=(), dtype=tf.string)),
    ])
    self.assertEqual(expected_structure, client_data.element_type_structure)

    # Per-client datasets match a direct from_tensor_slices construction.
    for client_id in TEST_DATA_WITH_ORDEREDDICTS:
      self.assertSameDatasetsOfDicts(
          tf.data.Dataset.from_tensor_slices(
              TEST_DATA_WITH_ORDEREDDICTS[client_id]),
          client_data.create_tf_dataset_for_client(client_id))
Example #16
0
 def test_create_tf_dataset_for_client(self):
   """Each pseudo-client dataset is the raw data with 'x' shifted by 10*index."""
   base = from_tensor_slices_client_data.FromTensorSlicesClientData(TEST_DATA)
   transformed = transforming_client_data.TransformingClientData(
       base, _test_transform_cons, 9)
   # Pseudo-client ids look like '<raw id>_<expansion index>'.
   id_pattern = r'^(.*)_(\d*)$'
   for client_id in transformed.client_ids:
     dataset = transformed.create_tf_dataset_for_client(client_id)
     self.assertIsInstance(dataset, tf.data.Dataset)
     match = re.search(id_pattern, client_id)
     raw_id = match.group(1)
     index = int(match.group(2))
     for i, actual in enumerate(dataset):
       actual = self.evaluate(actual)
       expected = {k: v[i].copy() for k, v in TEST_DATA[raw_id].items()}
       expected['x'] += 10 * index
       # Same keys, and every value matches element-wise.
       self.assertCountEqual(actual, expected)
       for key, value in actual.items():
         self.assertAllEqual(value, expected[key])
Example #17
0
def get_synthetic(num_clients=2):
  """Quickly returns a small synthetic dataset, useful for unit tests, etc.

  Each client produced has exactly 10 examples, one of each digit. The images
  are derived from a fixed set of hard-coded images, and transformed using
  `tff.simulation.datasets.emnist.infinite_emnist` to produce the desired
  number of clients.

  Args:
    num_clients: The number of synthetic clients to generate.

  Returns:
     A `tff.simulation.ClientData` object that matches the characteristics
     (other than size) of those provided by
     `tff.simulation.datasets.emnist.load_data`.
  """
  # Expand a single-client in-memory dataset into `num_clients`
  # pseudo-clients via get_infinite.
  return get_infinite(
      # Base ClientData with one client
      from_tensor_slices_client_data.FromTensorSlicesClientData(
          {'synthetic': _get_synthetic_digits_data()}),
      num_pseudo_clients=num_clients)
  def test_dataset_computation_where_client_data_is_ordered_dicts(self):
    """dataset_computation maps an id to a sequence of OrderedDict elements.

    Verifies the computation's type signature (string -> sequence of
    OrderedDicts with keys 'x', 'y', 'z') and that invoking it per client
    yields a `tf.data.Dataset` equal to the source slices.
    """
    client_data = from_tensor_slices_client_data.FromTensorSlicesClientData(
        TEST_DATA_WITH_ORDEREDDICTS)

    dataset_computation = client_data.dataset_computation
    self.assertIsInstance(dataset_computation, computation_base.Computation)

    # Expected signature: (string -> Sequence[OrderedDict]). 'x' keeps its
    # known length-2 shape; 'y' and 'z' are scalars (shape None here).
    expected_dataset_comp_type_signature = computation_types.FunctionType(
        computation_types.to_type(tf.string),
        computation_types.SequenceType(
            collections.OrderedDict([
                ('x',
                 computation_types.TensorType(
                     client_data.element_type_structure['x'].dtype,
                     tf.TensorShape(2))),
                ('y',
                 computation_types.TensorType(
                     client_data.element_type_structure['y'].dtype, None)),
                ('z',
                 computation_types.TensorType(
                     client_data.element_type_structure['z'].dtype, None))
            ])))

    self.assertTrue(
        dataset_computation.type_signature.is_equivalent_to(
            expected_dataset_comp_type_signature))

    # Iterate over each client, invoking the computation and ensuring
    # we received a tf.data.Dataset with the correct data.
    for client_id in TEST_DATA_WITH_ORDEREDDICTS:
      dataset = dataset_computation(client_id)
      self.assertIsInstance(dataset, tf.data.Dataset)

      expected_dataset = tf.data.Dataset.from_tensor_slices(
          TEST_DATA_WITH_ORDEREDDICTS[client_id])
      self.assertSameDatasetsOfDicts(expected_dataset, dataset)
 def test_init_raises_error_if_slices_is_not_dict(self):
   """A non-dict tensor_slices argument is rejected with TypeError."""
   self.assertRaises(
       TypeError,
       from_tensor_slices_client_data.FromTensorSlicesClientData,
       TEST_DATA_NOT_DICT)
 def test_init_raises_error_if_slices_are_namedtuples(self):
   """Namedtuple-valued slices are rejected with TypeError."""
   self.assertRaises(
       TypeError,
       from_tensor_slices_client_data.FromTensorSlicesClientData,
       TEST_DATA_WITH_NAMEDTUPLES)
 def test_empty(self):
     """A client with zero examples is rejected with ValueError."""
     self.assertRaises(
         ValueError,
         from_tensor_slices_client_data.FromTensorSlicesClientData,
         {'a': []})
 def test_init_raises_error_if_slices_are_inconsistent_type(self):
   """Slices with mismatched element types are rejected with TypeError."""
   self.assertRaises(
       TypeError,
       from_tensor_slices_client_data.FromTensorSlicesClientData,
       TEST_DATA_WITH_INCONSISTENT_TYPE)
 def test_init_raises_error_if_slices_are_part_list_and_part_dict(self):
   """Mixing list- and dict-structured slices is rejected with TypeError."""
   self.assertRaises(
       TypeError,
       from_tensor_slices_client_data.FromTensorSlicesClientData,
       TEST_DATA_WITH_PART_LIST_AND_PART_DICT)