Example No. 1
0
 def test_different_preprocessors(self):
   """padded_batch_client_datasets rejects datasets with distinct preprocessors.

   Each ClientDataset below is given its own, separately constructed
   BatchPreprocessor, so the two datasets do not share one Preprocessor object.
   """
   with self.assertRaisesRegex(
       ValueError,
       'client_datasets should have the identical Preprocessor object'):
     datasets = [
         client_datasets.ClientDataset({'x': np.arange(10)},
                                       client_datasets.BatchPreprocessor()),
         client_datasets.ClientDataset({'x': np.arange(10, 11)},
                                       client_datasets.BatchPreprocessor()),
     ]
     # `list` drains the result in case the check is performed lazily.
     list(
         client_datasets.padded_batch_client_datasets(datasets, batch_size=4))
Example No. 2
0
 def test_append(self):
   """`append` returns a new preprocessor and leaves the original one intact."""
   preprocessor = client_datasets.BatchPreprocessor([
       # Flattens `pixels`.
       lambda x: {
           **x, 'pixels': x['pixels'].reshape([-1, 28 * 28])
       },
       # Introduce `binary_label`.
       lambda x: {
           **x, 'binary_label': x['label'] % 2
       },
   ])
   new_preprocessor = preprocessor.append(lambda x: {
       **x, 'sum_pixels': np.sum(x['pixels'], axis=1)
   })
   self.assertIsNot(new_preprocessor, preprocessor)
   # Seeded so that any failure can be reproduced exactly.
   rng = np.random.RandomState(0)
   fake_emnist = {
       'pixels': rng.uniform(size=(10, 28, 28)),
       'label': rng.randint(10, size=(10,))
   }
   result = new_preprocessor(fake_emnist)
   self.assertIs(result['label'], fake_emnist['label'])
   npt.assert_equal(
       result, {
           'pixels': fake_emnist['pixels'].reshape([-1, 28 * 28]),
           'label': fake_emnist['label'],
           'binary_label': fake_emnist['label'] % 2,
           'sum_pixels': np.sum(fake_emnist['pixels'], axis=(1, 2))
       })
   # `assertIsNot` only proves a new object was returned. Also check that the
   # original preprocessor still runs just its two steps, i.e. `append` did
   # not add `sum_pixels` to it as a side effect.
   npt.assert_equal(
       preprocessor(fake_emnist), {
           'pixels': fake_emnist['pixels'].reshape([-1, 28 * 28]),
           'label': fake_emnist['label'],
           'binary_label': fake_emnist['label'] % 2,
       })
Example No. 3
0
 def test_different_preprocessors(self):
   """buffered_shuffle_batch_client_datasets rejects distinct preprocessors.

   The two ClientDatasets are each given a separately constructed
   BatchPreprocessor, so they do not share one Preprocessor object.
   """
   with self.assertRaisesRegex(
       ValueError,
       'client_datasets should have the identical Preprocessor object'):
     lower = client_datasets.ClientDataset({'x': np.arange(10, 20)},
                                           client_datasets.BatchPreprocessor())
     upper = client_datasets.ClientDataset({'x': np.arange(20, 30)},
                                           client_datasets.BatchPreprocessor())
     batches = client_datasets.buffered_shuffle_batch_client_datasets(
         [lower, upper],
         batch_size=4,
         buffer_size=16,
         rng=np.random.RandomState(0))
     # Drain the result so that a lazily performed check is also triggered.
     list(batches)
Example No. 4
0
 def test_all_examples(self):
   """all_examples returns the stored examples, after any preprocessing."""
   raw_examples = {'a': np.arange(3), 'b': np.arange(6).reshape([3, 2])}

   with self.subTest('no preprocessing'):
     plain = client_datasets.ClientDataset(raw_examples)
     npt.assert_equal(plain.all_examples(), raw_examples)

   with self.subTest('with preprocessing'):
     # The preprocessor keeps only a new key `c`, computed as `a + 1`.
     add_one = client_datasets.BatchPreprocessor([lambda x: {'c': x['a'] + 1}])
     processed = client_datasets.ClientDataset(raw_examples, add_one)
     npt.assert_equal(processed.all_examples(), {'c': [1, 2, 3]})
Example No. 5
0
 def test_padded_batch(self):
   """padded_batch zero-pads short batches and flags real rows in `__mask__`."""
   dataset = client_datasets.ClientDataset(
       {
           'a': np.arange(5),
           'b': np.arange(10).reshape([5, 2])
       }, client_datasets.BatchPreprocessor([lambda x: {
           **x, 'a': 2 * x['a']
       }]))

   def check_repeatable(view, expected_batches):
     # `view` should be repeatedly iterable, so drain it twice.
     for _ in range(2):
       batches = list(view)
       self.assertLen(batches, len(expected_batches))
       for got, want in zip(batches, expected_batches):
         npt.assert_equal(got, want)

   with self.subTest('1 bucket, kwargs'):
     check_repeatable(
         dataset.padded_batch(batch_size=3), [
             {
                 'a': [0, 2, 4],
                 'b': [[0, 1], [2, 3], [4, 5]],
                 '__mask__': [True, True, True],
             },
             {
                 'a': [6, 8, 0],
                 'b': [[6, 7], [8, 9], [0, 0]],
                 '__mask__': [True, True, False]
             },
         ])

   with self.subTest('2 buckets, kwargs override'):
     # With two size buckets the final batch has length 2 rather than 4.
     check_repeatable(
         dataset.padded_batch(
             client_datasets.PaddedBatchHParams(batch_size=4),
             num_batch_size_buckets=2), [
                 {
                     'a': [0, 2, 4, 6],
                     'b': [[0, 1], [2, 3], [4, 5], [6, 7]],
                     '__mask__': [True, True, True, True],
                 },
                 {
                     'a': [8, 0],
                     'b': [[8, 9], [0, 0]],
                     '__mask__': [True, False]
                 },
             ])
Example No. 6
0
  def test_slice(self):
    """Slices of a ClientDataset still apply its preprocessor (doubling `a`)."""
    dataset = client_datasets.ClientDataset(
        {
            'a': np.arange(5),
            'b': np.arange(10).reshape([5, 2])
        }, client_datasets.BatchPreprocessor([lambda x: {
            **x, 'a': 2 * x['a']
        }]))

    # (subtest name, index object, expected first batch of size 3).
    cases = [
        ('slice [:3]', slice(None, 3), {
            'a': [0, 2, 4],
            'b': [[0, 1], [2, 3], [4, 5]]
        }),
        ('slice [-3:]', slice(-3, None), {
            'a': [4, 6, 8],
            'b': [[4, 5], [6, 7], [8, 9]]
        }),
    ]
    for name, index, expected in cases:
      with self.subTest(name):
        sliced = dataset[index]
        batch = next(iter(sliced.batch(batch_size=3)))
        npt.assert_equal(batch, expected)
Example No. 7
0
 def test_preprocessor(self):
   """padded_batch_client_datasets applies the datasets' preprocessor."""
   increment = client_datasets.BatchPreprocessor([lambda x: {'x': x['x'] + 1}])
   dataset = client_datasets.ClientDataset({'x': np.arange(6)}, increment)
   batches = list(
       client_datasets.padded_batch_client_datasets([dataset], batch_size=5))
   self.assertLen(batches, 2)
   # Six examples with batch_size=5: one full batch, then a single real
   # example followed by zero padding that is masked out.
   npt.assert_equal(batches[0], {
       'x': np.arange(5) + 1,
       '__mask__': [True] * 5
   })
   npt.assert_equal(batches[1], {
       'x': [6, 0, 0, 0, 0],
       '__mask__': [True] + [False] * 4
   })
Example No. 8
0
 def test_preprocessor(self):
   """Buffered shuffle batching applies the datasets' preprocessor."""
   increment = client_datasets.BatchPreprocessor([lambda x: {'x': x['x'] + 1}])
   dataset = client_datasets.ClientDataset({'x': np.arange(6)}, increment)
   iterator = client_datasets.buffered_shuffle_batch_client_datasets(
       [dataset], batch_size=5, buffer_size=16, rng=np.random.RandomState(0))
   batches = list(iterator)
   self.assertLen(batches, 2)
   # The order is fixed by the seeded RandomState. The expected batches carry
   # no `__mask__` key, and the last one holds only the leftover example.
   npt.assert_equal(batches[0], {'x': [6, 3, 2, 4, 1]})
   npt.assert_equal(batches[1], {'x': [5]})
Example No. 9
0
 def test_batch(self):
   """Exercises batch() via kwargs, via BatchHParams, and via both together."""
   dataset = client_datasets.ClientDataset(
       {
           'a': np.arange(5),
           'b': np.arange(10).reshape([5, 2])
       }, client_datasets.BatchPreprocessor([lambda x: {
           **x, 'a': 2 * x['a']
       }]))

   def check_repeatable(view, expected_batches):
     # `view` should be repeatedly iterable, so drain it twice.
     for _ in range(2):
       batches = list(view)
       self.assertLen(batches, len(expected_batches))
       for got, want in zip(batches, expected_batches):
         npt.assert_equal(got, want)

   with self.subTest('keep remainder, kwargs'):
     check_repeatable(
         dataset.batch(batch_size=3), [
             {'a': [0, 2, 4], 'b': [[0, 1], [2, 3], [4, 5]]},
             {'a': [6, 8], 'b': [[6, 7], [8, 9]]},
         ])

   with self.subTest('drop remainder, hparams'):
     check_repeatable(
         dataset.batch(
             client_datasets.BatchHParams(batch_size=3, drop_remainder=True)),
         [{'a': [0, 2, 4], 'b': [[0, 1], [2, 3], [4, 5]]}])

   with self.subTest('no op drop remainder, hparams and kwargs'):
     # batch_size equals the number of examples, so nothing is dropped.
     check_repeatable(
         dataset.batch(
             client_datasets.BatchHParams(batch_size=5), drop_remainder=True),
         [{
             'a': [0, 2, 4, 6, 8],
             'b': [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
         }])
Example No. 10
0
  def test_preprocessor(self):
    """BatchPreprocessor applies each of its functions without altering input."""
    preprocessor = client_datasets.BatchPreprocessor([
        # Flattens `pixels`.
        lambda x: {
            **x, 'pixels': x['pixels'].reshape([-1, 28 * 28])
        },
        # Introduce `binary_label`.
        lambda x: {
            **x, 'binary_label': x['label'] % 2
        },
    ])
    # Seeded so that any failure can be reproduced exactly.
    rng = np.random.RandomState(0)
    fake_emnist = {
        'pixels': rng.uniform(size=(10, 28, 28)),
        'label': rng.randint(10, size=(10,))
    }

    with self.subTest('2 step preprocessing'):
      result = preprocessor(fake_emnist)
      npt.assert_equal(
          result, {
              'pixels': fake_emnist['pixels'].reshape([-1, 28 * 28]),
              'label': fake_emnist['label'],
              'binary_label': fake_emnist['label'] % 2
          })
      # The caller's batch must be left untouched: no new keys, same shape.
      self.assertEqual(set(fake_emnist), {'pixels', 'label'})
      self.assertEqual(fake_emnist['pixels'].shape, (10, 28, 28))

    with self.subTest('0 step preprocessing'):
      # Edge case: a preprocessor built with no functions should hand back
      # examples equal to its input.
      npt.assert_equal(
          client_datasets.BatchPreprocessor()(fake_emnist), fake_emnist)
Example No. 11
0
  def test_shuffle_repeat_batch(self):
    """Checks batch counts, seeded shuffling, and skip_shuffle ordering."""
    dataset = client_datasets.ClientDataset(
        {
            'a': np.arange(5),
            'b': np.arange(10).reshape([5, 2])
        }, client_datasets.BatchPreprocessor([lambda x: {
            **x, 'a': 2 * x['a']
        }]))

    # (keyword arguments, expected number of batches) for different
    # num_epochs/num_steps/drop_remainder combinations on 5 examples.
    length_cases = [
        # num_epochs and num_steps left at their defaults.
        (dict(batch_size=5), 1),
        (dict(batch_size=3), 2),
        (dict(batch_size=1), 5),
        (dict(batch_size=7, drop_remainder=True), 0),
        (dict(batch_size=5, drop_remainder=True), 1),
        (dict(batch_size=3, drop_remainder=True), 1),
        (dict(batch_size=2, drop_remainder=True), 2),
        (dict(batch_size=1, drop_remainder=True), 5),
        # Unbounded epochs: num_steps alone decides the count.
        (dict(batch_size=5, num_epochs=None, num_steps=4), 4),
        (dict(batch_size=3, num_epochs=None, num_steps=4), 4),
        (dict(batch_size=1, num_epochs=None, num_steps=4), 4),
        (dict(batch_size=5, num_epochs=None, num_steps=4,
              drop_remainder=True), 4),
        (dict(batch_size=3, num_epochs=None, num_steps=4,
              drop_remainder=True), 4),
        (dict(batch_size=1, num_epochs=None, num_steps=4,
              drop_remainder=True), 4),
        # Default num_epochs: these counts show num_steps acting only as a cap.
        (dict(batch_size=5, num_steps=4), 1),
        (dict(batch_size=3, num_steps=4), 2),
        (dict(batch_size=1, num_steps=4), 4),
        (dict(batch_size=5, num_steps=4, drop_remainder=True), 1),
        (dict(batch_size=3, num_steps=4, drop_remainder=True), 1),
        (dict(batch_size=1, num_steps=4, drop_remainder=True), 4),
    ]

    with self.subTest('repeating'):
      for kwargs, expected_len in length_cases:
        self.assertLen(
            list(dataset.shuffle_repeat_batch(**kwargs)), expected_len)

      for drop_remainder in [False, True]:
        # 100 is as good as forever.
        endless = dataset.shuffle_repeat_batch(
            batch_size=3, num_epochs=None, drop_remainder=drop_remainder)
        self.assertLen(list(itertools.islice(endless, 100)), 100)

    # Check proper shuffling.
    with self.subTest('shuffling'):
      view = dataset.shuffle_repeat_batch(
          batch_size=3, num_epochs=None, num_steps=4, seed=1)
      expected = [
          {'a': [4, 2, 8], 'b': [[4, 5], [2, 3], [8, 9]]},
          {'a': [0, 6, 4], 'b': [[0, 1], [6, 7], [4, 5]]},
          {'a': [8, 6, 0], 'b': [[8, 9], [6, 7], [0, 1]]},
          {'a': [2, 6, 0], 'b': [[2, 3], [6, 7], [0, 1]]},
      ]
      # `view` should be repeatedly iterable and replay the same seeded order.
      for _ in range(2):
        batches = list(view)
        self.assertLen(batches, 4)
        for got, want in zip(batches, expected):
          npt.assert_equal(got, want)

    with self.subTest('skip shuffling'):
      batches = list(
          dataset.shuffle_repeat_batch(batch_size=3, skip_shuffle=True))
      self.assertLen(batches, 2)
      # Original order is kept; the second batch wraps around to the first
      # example to fill up to batch_size.
      npt.assert_equal(batches[0], {
          'a': [0, 2, 4],
          'b': [[0, 1], [2, 3], [4, 5]]
      })
      npt.assert_equal(batches[1], {
          'a': [6, 8, 0],
          'b': [[6, 7], [8, 9], [0, 1]]
      })