Example #1
0
def shuffle_repeat_batch_federated_data(
        fd: FederatedData,
        batch_size: int,
        client_buffer_size: int,
        example_buffer_size: int,
        seed: Optional[int] = None) -> Iterator[client_datasets.Examples]:
    """Shuffle-repeat-batch all client datasets for a centralized baseline.

    Shuffling is done using two levels of buffered shuffling, first at the
    client level, then at the example level.

    This produces an infinite stream of batches. itertools.islice() can be
    used to cap the number of batches, if so desired.

    Args:
      fd: Federated dataset.
      batch_size: Desired batch size.
      client_buffer_size: Buffer size for client level shuffling.
      example_buffer_size: Buffer size for example level shuffling.
      seed: Optional RNG seed. The same RNG drives both the client-level
        shuffle (through a derived seed) and the example-level shuffle.

    Yields:
      Batches of preprocessed examples.
    """
    rng = np.random.RandomState(seed)
    # Derive the client-shuffling seed from `rng`, so a single `seed` controls
    # both shuffling levels. The value is drawn from [0, 2**32).
    # `dtype=np.int64` is explicit because RandomState.randint defaults to the
    # C `long` dtype. That dtype is 32 bits on Windows (and on 32-bit
    # platforms), where `high=2**32` would raise ValueError. Where `long` is
    # already 64 bits, the drawn value is unchanged.
    client_seed = rng.randint(1 << 32, dtype=np.int64)
    # `fd.shuffled_clients` is called eagerly here, as in a generator
    # expression's outermost iterable. Only the client datasets are kept; the
    # client ids are discarded.
    datasets = (
        client_dataset
        for _, client_dataset in fd.shuffled_clients(client_buffer_size,
                                                     client_seed))
    yield from client_datasets.buffered_shuffle_batch_client_datasets(
        datasets,
        batch_size=batch_size,
        buffer_size=example_buffer_size,
        rng=rng)
Example #2
0
 def test_multi(self):
   # Four clients holding 10, 1, 4 and 2 examples: 17 examples in total,
   # covering the contiguous range 0..16.
   bounds = [(0, 10), (10, 11), (11, 15), (15, 17)]
   clients = [
       client_datasets.ClientDataset({'x': np.arange(lo, hi)})
       for lo, hi in bounds
   ]
   batches = list(
       client_datasets.buffered_shuffle_batch_client_datasets(
           clients,
           batch_size=4,
           buffer_size=16,
           rng=np.random.RandomState(0)))
   # 17 examples in batches of 4: four full batches, then one leftover.
   expected_x = [
       [1, 6, 16, 8],
       [9, 13, 4, 2],
       [14, 10, 7, 15],
       [11, 3, 0, 5],
       [12],
   ]
   self.assertLen(batches, len(expected_x))
   for actual, want in zip(batches, expected_x):
     npt.assert_equal(actual, {'x': want})
Example #3
0
 def test_single_buffer_1(self):
   # A single client with 6 examples, batched 5 at a time with a buffer of
   # size 1. The expected output keeps the original order (no shuffling).
   only_client = client_datasets.ClientDataset({'x': np.arange(6)})
   stream = client_datasets.buffered_shuffle_batch_client_datasets(
       [only_client],
       batch_size=5,
       buffer_size=1,
       rng=np.random.RandomState(0))
   batches = list(stream)
   self.assertLen(batches, 2)
   first, second = batches
   npt.assert_equal(first, {'x': np.arange(5)})
   npt.assert_equal(second, {'x': [5]})
Example #4
0
 def test_different_features(self):
   # The second client exposes feature 'y' instead of 'x'; mixing datasets
   # with different feature names must raise. Everything is built inside the
   # context manager so that any ValueError raised along the way is checked.
   expected_message = 'client_datasets should have identical features'
   with self.assertRaisesRegex(ValueError, expected_message):
     mismatched = [
         client_datasets.ClientDataset({'x': np.arange(10)}),
         client_datasets.ClientDataset({'y': np.arange(10, 11)}),
     ]
     stream = client_datasets.buffered_shuffle_batch_client_datasets(
         mismatched,
         batch_size=4,
         buffer_size=16,
         rng=np.random.RandomState(0))
     # The function is a generator, so it must be drained to run its checks.
     list(stream)
Example #5
0
 def test_single_buffer_4(self):
   # A single client with 8 examples, a shuffle buffer of 4 and batches of 6:
   # one full batch followed by the 2 remaining examples.
   only_client = client_datasets.ClientDataset({'x': np.arange(8)})
   batches = list(
       client_datasets.buffered_shuffle_batch_client_datasets(
           [only_client],
           batch_size=6,
           buffer_size=4,
           rng=np.random.RandomState(0)))
   expected_x = [[2, 4, 5, 6, 7, 3], [1, 0]]
   self.assertLen(batches, len(expected_x))
   for actual, want in zip(batches, expected_x):
     npt.assert_equal(actual, {'x': want})
Example #6
0
 def test_different_preprocessors(self):
   # Each client gets its own, separately constructed BatchPreprocessor, so
   # the two preprocessors are distinct objects; this must be rejected.
   expected_message = (
       'client_datasets should have the identical Preprocessor object')
   with self.assertRaisesRegex(ValueError, expected_message):
     clients = [
         client_datasets.ClientDataset({'x': np.arange(10, 20)},
                                       client_datasets.BatchPreprocessor()),
         client_datasets.ClientDataset({'x': np.arange(20, 30)},
                                       client_datasets.BatchPreprocessor()),
     ]
     stream = client_datasets.buffered_shuffle_batch_client_datasets(
         clients,
         batch_size=4,
         buffer_size=16,
         rng=np.random.RandomState(0))
     # Drain the generator so the validation actually executes.
     list(stream)
Example #7
0
 def test_preprocessor(self):

   def add_one(example):
     # Shifts every value of feature 'x' up by one.
     return {'x': example['x'] + 1}

   # Raw examples are 0..5. After the preprocessor they become 1..6, which is
   # exactly what the shuffled batches below contain.
   dataset = client_datasets.ClientDataset(
       {'x': np.arange(6)}, client_datasets.BatchPreprocessor([add_one]))
   batches = list(
       client_datasets.buffered_shuffle_batch_client_datasets(
           [dataset],
           batch_size=5,
           buffer_size=16,
           rng=np.random.RandomState(0)))
   self.assertLen(batches, 2)
   head, tail = batches
   npt.assert_equal(head, {'x': [6, 3, 2, 4, 1]})
   npt.assert_equal(tail, {'x': [5]})
Example #8
0
 def test_empty(self):
   # With no client datasets at all, the stream must produce no batches.
   stream = client_datasets.buffered_shuffle_batch_client_datasets(
       [],
       batch_size=5,
       buffer_size=10,
       rng=np.random.RandomState(0))
   self.assertListEqual(list(stream), [])