Example #1
    def test_federated_averaging(self):
        client_optimizer = optimizers.sgd(learning_rate=1.0)
        server_optimizer = optimizers.sgd(learning_rate=1.0)
        client_batch_hparams = client_datasets.ShuffleRepeatBatchHParams(
            batch_size=2, num_epochs=1, seed=0)
        algorithm = fed_avg.federated_averaging(grad_fn, client_optimizer,
                                                server_optimizer,
                                                client_batch_hparams)

        with self.subTest('init'):
            state = algorithm.init({'w': jnp.array([0., 2., 4.])})
            npt.assert_array_equal(state.params['w'], [0., 2., 4.])
            self.assertLen(state.opt_state, 2)

        with self.subTest('apply'):
            clients = [
                (b'cid0',
                 client_datasets.ClientDataset({'x': jnp.array([2., 4., 6.])}),
                 jax.random.PRNGKey(0)),
                (b'cid1',
                 client_datasets.ClientDataset({'x': jnp.array([8., 10.])}),
                 jax.random.PRNGKey(1)),
            ]
            state, client_diagnostics = algorithm.apply(state, clients)
            npt.assert_allclose(state.params['w'], [0., 1.5655555, 3.131111])
            npt.assert_allclose(client_diagnostics[b'cid0']['delta_l2_norm'],
                                1.4534444262)
            npt.assert_allclose(client_diagnostics[b'cid1']['delta_l2_norm'],
                                0.2484521282)
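The init/apply pair above is the entire training surface of the algorithm object: init builds server state from initial parameters, and each apply call consumes a list of (client_id, dataset, rng) triples and returns the updated state plus per-client diagnostics. A minimal driver loop consistent with these examples, as a hypothetical sketch (client_sampler as in Example #14; initial_params and num_rounds are assumed names):

state = algorithm.init(initial_params)
for _ in range(num_rounds):
    # Each sampled client is a (client_id, ClientDataset, PRNGKey) triple.
    clients = client_sampler.sample()
    state, client_diagnostics = algorithm.apply(state, clients)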
Example #2
    def test_fed_prox(self):
        client_optimizer = optimizers.sgd(learning_rate=1.0)
        server_optimizer = optimizers.sgd(learning_rate=1.0)
        client_batch_hparams = client_datasets.ShuffleRepeatBatchHParams(
            batch_size=2, num_epochs=1, seed=0)
        algorithm = fed_prox.fed_prox(per_example_loss,
                                      client_optimizer,
                                      server_optimizer,
                                      client_batch_hparams,
                                      proximal_weight=0.01)

        with self.subTest('init'):
            state = algorithm.init({'w': jnp.array(4.)})
            npt.assert_array_equal(state.params['w'], 4.)
            self.assertLen(state.opt_state, 2)

        with self.subTest('apply'):
            clients = [
                (b'cid0',
                 client_datasets.ClientDataset({'x': jnp.array([2., 4., 6.])}),
                 jax.random.PRNGKey(0)),
                (b'cid1',
                 client_datasets.ClientDataset({'x': jnp.array([8., 10.])}),
                 jax.random.PRNGKey(1)),
            ]
            state, client_diagnostics = algorithm.apply(state, clients)
            npt.assert_allclose(state.params['w'], -3.77)
            npt.assert_allclose(client_diagnostics[b'cid0']['delta_l2_norm'],
                                6.95)
            npt.assert_allclose(client_diagnostics[b'cid1']['delta_l2_norm'],
                                9.)
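For context, fed_prox differs from fed_avg by adding a proximal term to the client objective that penalizes drift from the current server parameters, scaled by proximal_weight (0.01 above). A rough sketch of that objective, not the library's internal implementation (per_example_loss and proximal_weight as in the test):

import jax
import jax.numpy as jnp

def proximal_loss(params, server_params, batch, rng):
    # Squared L2 distance between client and server parameters, summed over
    # every leaf of the parameter pytree.
    sq_dists = jax.tree_util.tree_map(lambda p, s: jnp.sum((p - s)**2),
                                      params, server_params)
    penalty = 0.5 * sum(jax.tree_util.tree_leaves(sq_dists))
    return per_example_loss(params, batch, rng) + proximal_weight * penalty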
Example #3
 def test_multi(self):
   batches = list(
       client_datasets.padded_batch_client_datasets([
           client_datasets.ClientDataset({'x': np.arange(10)}),
           client_datasets.ClientDataset({'x': np.arange(10, 11)}),
           client_datasets.ClientDataset({'x': np.arange(11, 15)}),
           client_datasets.ClientDataset({'x': np.arange(15, 17)})
       ],
                                                    batch_size=4))
   self.assertLen(batches, 5)
   npt.assert_equal(batches[0], {
       'x': [0, 1, 2, 3],
       '__mask__': [True, True, True, True]
   })
   npt.assert_equal(batches[1], {
       'x': [4, 5, 6, 7],
       '__mask__': [True, True, True, True]
   })
   npt.assert_equal(batches[2], {
       'x': [8, 9, 10, 11],
       '__mask__': [True, True, True, True]
   })
   npt.assert_equal(batches[3], {
       'x': [12, 13, 14, 15],
       '__mask__': [True, True, True, True]
   })
   npt.assert_equal(batches[4], {
       'x': [16, 0, 0, 0],
       '__mask__': [True, False, False, False]
   })
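padded_batch_client_datasets always emits fixed-size batches and marks the padding rows with a boolean __mask__ feature, as the final batch above shows. Downstream code should drop the masked-out rows before aggregating; a minimal sketch using the batches from this test:

import numpy as np

# Boolean indexing with __mask__ excludes the padded zeros in the last batch.
total = sum(np.sum(batch['x'][batch['__mask__']]) for batch in batches)
assert total == sum(range(17))  # the concatenated datasets held 0..16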
Example #4
 def test_multi(self):
   batches = list(
       client_datasets.buffered_shuffle_batch_client_datasets(
           [
               client_datasets.ClientDataset({'x': np.arange(10)}),
               client_datasets.ClientDataset({'x': np.arange(10, 11)}),
               client_datasets.ClientDataset({'x': np.arange(11, 15)}),
               client_datasets.ClientDataset({'x': np.arange(15, 17)})
           ],
           batch_size=4,
           buffer_size=16,
           rng=np.random.RandomState(0)))
   self.assertLen(batches, 5)
   npt.assert_equal(batches[0], {
       'x': [1, 6, 16, 8],
   })
   npt.assert_equal(batches[1], {
       'x': [9, 13, 4, 2],
   })
   npt.assert_equal(batches[2], {
       'x': [14, 10, 7, 15],
   })
   npt.assert_equal(batches[3], {
       'x': [11, 3, 0, 5],
   })
   npt.assert_equal(batches[4], {
       'x': [12],
   })
Example #5
    def test_client_delta_clip_norm(self):
        base_optimizer = optimizers.sgd(learning_rate=1.0)
        train_batch_hparams = client_datasets.ShuffleRepeatBatchHParams(
            batch_size=2, num_epochs=1, seed=0)
        grad_batch_hparams = client_datasets.PaddedBatchHParams(batch_size=2)
        server_learning_rate = 0.2
        algorithm = mime_lite.mime_lite(per_example_loss,
                                        base_optimizer,
                                        train_batch_hparams,
                                        grad_batch_hparams,
                                        server_learning_rate,
                                        client_delta_clip_norm=0.5)

        clients = [
            (b'cid0',
             client_datasets.ClientDataset({'x': jnp.array([0.2, 0.4, 0.6])}),
             jax.random.PRNGKey(0)),
            (b'cid1',
             client_datasets.ClientDataset({'x': jnp.array([0.8, 0.1])}),
             jax.random.PRNGKey(1)),
        ]
        state = algorithm.init({'w': jnp.array(4.)})
        state, client_diagnostics = algorithm.apply(state, clients)
        npt.assert_allclose(state.params['w'], 3.904)
        npt.assert_allclose(
            client_diagnostics[b'cid0']['clipped_delta_l2_norm'], 0.5)
        npt.assert_allclose(
            client_diagnostics[b'cid1']['clipped_delta_l2_norm'], 0.45000005)
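client_delta_clip_norm caps the L2 norm of each client's model delta before aggregation: cid0's delta is scaled down to exactly 0.5, while cid1's delta (norm about 0.45) is already under the cap and passes through unchanged. A standard clipping rule consistent with those diagnostics, as a sketch rather than the library's exact code:

import jax
import jax.numpy as jnp

def clip_delta(delta, clip_norm):
    # Global L2 norm across all leaves of the delta pytree.
    norm = jnp.sqrt(
        sum(jnp.sum(x**2) for x in jax.tree_util.tree_leaves(delta)))
    # Scale down only when the norm exceeds the cap (guarding against norm 0).
    scale = jnp.minimum(1.0, clip_norm / jnp.maximum(norm, 1e-15))
    return jax.tree_util.tree_map(lambda x: x * scale, delta)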
Example #6
  def test_mime(self):
    base_optimizer = optimizers.sgd(learning_rate=1.0)
    train_batch_hparams = client_datasets.ShuffleRepeatBatchHParams(
        batch_size=2, num_epochs=1, seed=0)
    grad_batch_hparams = client_datasets.PaddedBatchHParams(batch_size=2)
    server_learning_rate = 0.2
    algorithm = mime.mime(per_example_loss, base_optimizer, train_batch_hparams,
                          grad_batch_hparams, server_learning_rate)

    with self.subTest('init'):
      state = algorithm.init({'w': jnp.array(4.)})
      npt.assert_equal(state.params, {'w': jnp.array(4.)})
      self.assertLen(state.opt_state, 2)

    with self.subTest('apply'):
      clients = [
          (b'cid0',
           client_datasets.ClientDataset({'x': jnp.array([2., 4., 6.])}),
           jax.random.PRNGKey(0)),
          (b'cid1', client_datasets.ClientDataset({'x': jnp.array([8., 10.])}),
           jax.random.PRNGKey(1)),
      ]
      state, client_diagnostics = algorithm.apply(state, clients)
      npt.assert_allclose(state.params['w'], 2.08)
      npt.assert_allclose(client_diagnostics[b'cid0']['delta_l2_norm'], 12.)
      npt.assert_allclose(client_diagnostics[b'cid1']['delta_l2_norm'], 6.)
Example #7
 def test_different_features(self):
   with self.assertRaisesRegex(
       ValueError, 'client_datasets should have identical features'):
     list(
         client_datasets.padded_batch_client_datasets([
             client_datasets.ClientDataset({'x': np.arange(10)}),
             client_datasets.ClientDataset({'y': np.arange(10, 11)})
         ],
                                                      batch_size=4))
Example #8
 def test_different_preprocessors(self):
   with self.assertRaisesRegex(
       ValueError,
       'client_datasets should have the identical Preprocessor object'):
     list(
         client_datasets.padded_batch_client_datasets([
             client_datasets.ClientDataset(
                 {'x': np.arange(10)}, client_datasets.BatchPreprocessor()),
             client_datasets.ClientDataset({'x': np.arange(10, 11)},
                                           client_datasets.BatchPreprocessor())
         ],
                                                      batch_size=4))
Example #9
 def test_different_features(self):
   with self.assertRaisesRegex(
       ValueError, 'client_datasets should have identical features'):
     list(
         client_datasets.buffered_shuffle_batch_client_datasets(
             [
                 client_datasets.ClientDataset({'x': np.arange(10)}),
                 client_datasets.ClientDataset({'y': np.arange(10, 11)})
             ],
             batch_size=4,
             buffer_size=16,
             rng=np.random.RandomState(0)))
Example #10
 def test_all_examples(self):
   raw_examples = {'a': np.arange(3), 'b': np.arange(6).reshape([3, 2])}
   with self.subTest('no preprocessing'):
     npt.assert_equal(
         client_datasets.ClientDataset(raw_examples).all_examples(),
         raw_examples)
   with self.subTest('with preprocessing'):
     npt.assert_equal(
         client_datasets.ClientDataset(
             raw_examples,
             client_datasets.BatchPreprocessor([lambda x: {
                 'c': x['a'] + 1
             }])).all_examples(), {'c': [1, 2, 3]})
Example #11
 def test_different_preprocessors(self):
   with self.assertRaisesRegex(
       ValueError,
       'client_datasets should have the identical Preprocessor object'):
     list(
         client_datasets.buffered_shuffle_batch_client_datasets(
             [
                 client_datasets.ClientDataset(
                     {'x': np.arange(10, 20)},
                     client_datasets.BatchPreprocessor()),
                 client_datasets.ClientDataset(
                     {'x': np.arange(20, 30)},
                     client_datasets.BatchPreprocessor())
             ],
             batch_size=4,
             buffer_size=16,
             rng=np.random.RandomState(0)))
Example #12
 def sample(self):
   client_id = self._round_num + self._base
   dataset = client_datasets.ClientDataset({
       'x': np.array([client_id], dtype=np.int32),
       'y': np.array([client_id], dtype=np.int32) % 2
   })
   rng = None
   self._round_num += 1
   return [(client_id, dataset, rng)]
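This sample method is a deterministic stand-in for a real client sampler: each call fabricates one single-example client whose features encode the current round, packages it in the standard (client_id, dataset, rng) triple, and advances the round counter. A hypothetical walk-through (class name and attribute values assumed, not from the original test):

sampler = FakeClientSampler()      # hypothetical; _base=100, _round_num=0
[(cid, dataset, rng)] = sampler.sample()
assert cid == 100 and rng is None
[(cid, dataset, rng)] = sampler.sample()
assert cid == 101                  # the round counter advanced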
Example #13
 def test_single_buffer_1(self):
   batches = list(
       client_datasets.buffered_shuffle_batch_client_datasets(
           [client_datasets.ClientDataset({'x': np.arange(6)})],
           batch_size=5,
           buffer_size=1,
           rng=np.random.RandomState(0)))
   self.assertLen(batches, 2)
   # No shuffling.
   npt.assert_equal(batches[0], {'x': np.arange(5)})
   npt.assert_equal(batches[1], {'x': [5]})
Example #14
    def test_uniform_get_client_sampler(self):
        num_clients = 2
        round_num = 3
        client_to_data_mapping = {i: {'x': np.arange(i)} for i in range(100)}
        fd = in_memory_federated_data.InMemoryFederatedData(
            client_to_data_mapping)
        client_sampler = client_samplers.UniformGetClientSampler(
            fd, num_clients, seed=0, start_round_num=round_num)
        with self.subTest('sample'):
            client_rngs = jax.random.split(jax.random.PRNGKey(round_num),
                                           num_clients)
            expect = [(78, client_datasets.ClientDataset({'x': np.arange(78)}),
                       client_rngs[0]),
                      (56, client_datasets.ClientDataset({'x': np.arange(56)}),
                       client_rngs[1])]
            self.assert_clients_equal(client_sampler.sample(), expect)

        with self.subTest('set_round_num'):
            self.assertNotEqual(client_sampler._round_num, round_num)
            client_sampler.set_round_num(round_num)
            self.assert_clients_equal(client_sampler.sample(), expect)
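The set_round_num subtest above is the checkpoint-resume contract: sampling is a pure function of the round number, so rewinding the sampler replays exactly the clients of that round. A hypothetical resume sketch:

# After restoring a checkpoint saved at saved_round_num (assumed name),
# rewinding the sampler reproduces that round's client selection.
client_sampler.set_round_num(saved_round_num)
clients = client_sampler.sample()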
Example #15
 def __next__(
     self) -> Tuple[federated_data.ClientId, client_datasets.ClientDataset]:
   client_ids_result = self._client_ids_cursor.fetchone()
   if client_ids_result is None:
     raise StopIteration
   client_id = client_ids_result[0]
   examples_cursor = self._connection.execute(
       'SELECT serialized_example_proto FROM examples WHERE split_name = ? AND client_id = ? ORDER BY rowid;',
       [self._split_name, client_id])
   examples = [r[0] for r in examples_cursor.fetchall()]
   return client_id, client_datasets.ClientDataset(
       self._parse_examples(examples))
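Because __next__ raises StopIteration once the client-id cursor is exhausted, the enclosing object behaves as a plain Python iterator over (client_id, ClientDataset) pairs (assuming it also defines __iter__ returning self). A consumption sketch, with client_iterator as a hypothetical instance:

for client_id, dataset in client_iterator:
    # Each dataset holds the client's parsed examples in rowid order;
    # all_examples() materializes them, as in Example #10.
    examples = dataset.all_examples()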
Example #16
    def test_expectation_step(self):
        def per_example_loss(params, batch, rng):
            self.assertIsNotNone(rng)
            return jnp.square(params - batch['x'])

        trainer = hyp_cluster.ClientDeltaTrainer(models.grad(per_example_loss),
                                                 optimizers.sgd(0.5))
        batch_hparams = client_datasets.ShuffleRepeatBatchHParams(batch_size=1,
                                                                  num_epochs=5)

        cluster_params = [jnp.array(1.), jnp.array(-1.), jnp.array(3.14)]
        client_cluster_ids = {b'0': 0, b'1': 0, b'2': 1, b'3': 1, b'4': 0}
        # RNGs are not actually used.
        clients = [
            (b'0', client_datasets.ClientDataset({'x': np.array([1.1])}),
             jax.random.PRNGKey(0)),
            (b'1', client_datasets.ClientDataset({'x': np.array([0.9, 0.9])}),
             jax.random.PRNGKey(1)),
            (b'2', client_datasets.ClientDataset({'x': np.array([-1.1])}),
             jax.random.PRNGKey(2)),
            (b'3',
             client_datasets.ClientDataset({'x': np.array([-0.9, -0.9,
                                                           -0.9])}),
             jax.random.PRNGKey(3)),
            (b'4', client_datasets.ClientDataset({'x': np.array([-0.1])}),
             jax.random.PRNGKey(4)),
        ]
        cluster_delta_params = hyp_cluster.expectation_step(
            trainer=trainer,
            cluster_params=cluster_params,
            client_cluster_ids=client_cluster_ids,
            clients=clients,
            batch_hparams=batch_hparams)
        self.assertIsInstance(cluster_delta_params, list)
        self.assertLen(cluster_delta_params, 3)
        npt.assert_allclose(cluster_delta_params[0],
                            (-0.1 + 0.1 * 2 + 1.1) / 4)
        npt.assert_allclose(cluster_delta_params[1], (0.1 - 0.1 * 3) / 4,
                            rtol=1e-6)
        self.assertIsNone(cluster_delta_params[2])
Example #17
 def test_single_default_buckets(self):
   batches = list(
       client_datasets.padded_batch_client_datasets(
           [client_datasets.ClientDataset({'x': np.arange(6)})], batch_size=5))
   self.assertLen(batches, 2)
   npt.assert_equal(batches[0], {
       'x': np.arange(5),
       '__mask__': [True, True, True, True, True]
   })
   npt.assert_equal(batches[1], {
       'x': [5, 0, 0, 0, 0],
       '__mask__': [True, False, False, False, False]
   })
Example #18
 def test_single_buffer_4(self):
   batches = list(
       client_datasets.buffered_shuffle_batch_client_datasets(
           [client_datasets.ClientDataset({'x': np.arange(8)})],
           batch_size=6,
           buffer_size=4,
           rng=np.random.RandomState(0)))
   self.assertLen(batches, 2)
   npt.assert_equal(batches[0], {
       'x': [2, 4, 5, 6, 7, 3],
   })
   npt.assert_equal(batches[1], {
       'x': [1, 0],
   })
Example #19
 def test_single_has_buckets(self):
   batches = list(
       client_datasets.padded_batch_client_datasets(
           [client_datasets.ClientDataset({'x': np.arange(8)})],
           batch_size=6,
           num_batch_size_buckets=4))
   self.assertLen(batches, 2)
   npt.assert_equal(batches[0], {
       'x': np.arange(6),
       '__mask__': [True, True, True, True, True, True]
   })
   npt.assert_equal(batches[1], {
       'x': [6, 7, 0],
       '__mask__': [True, True, False]
   })
Example #20
def bench_client_dataset(preprocess, mode, batch_size=128, num_steps=100):
    """Benchmarks ClientDataset."""
    preprocessor = client_datasets.NoOpBatchPreprocessor
    if preprocess:
        preprocessor = preprocessor.append(f)
    dataset = client_datasets.ClientDataset(FAKE_MNIST, preprocessor)
    if mode == 'train':
        batches = dataset.shuffle_repeat_batch(batch_size=batch_size,
                                               num_steps=num_steps)
    else:
        batches = dataset.padded_batch(batch_size=batch_size,
                                       num_batch_size_buckets=4)
    n = 0
    for _ in batches:
        n += 1
    return n
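A hypothetical way to drive this benchmark (the timeit usage is illustrative and not part of the original script):

import timeit

# 'train' exercises shuffle_repeat_batch; any other mode falls through to
# padded_batch with bucketing.
for preprocess in (False, True):
    for mode in ('train', 'eval'):
        seconds = timeit.timeit(
            lambda: bench_client_dataset(preprocess, mode), number=10)
        print(preprocess, mode, seconds)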
Example #21
 def test_padded_batch(self):
   d = client_datasets.ClientDataset(
       {
           'a': np.arange(5),
           'b': np.arange(10).reshape([5, 2])
       }, client_datasets.BatchPreprocessor([lambda x: {
           **x, 'a': 2 * x['a']
       }]))
   with self.subTest('1 bucket, kwargs'):
     view = d.padded_batch(batch_size=3)
     # `view` should be repeatedly iterable.
     for _ in range(2):
       batches = list(view)
       self.assertLen(batches, 2)
       npt.assert_equal(
           batches[0], {
               'a': [0, 2, 4],
               'b': [[0, 1], [2, 3], [4, 5]],
               '__mask__': [True, True, True],
           })
       npt.assert_equal(
           batches[1], {
               'a': [6, 8, 0],
               'b': [[6, 7], [8, 9], [0, 0]],
               '__mask__': [True, True, False]
           })
   with self.subTest('2 buckets, kwargs override'):
     view = d.padded_batch(
         client_datasets.PaddedBatchHParams(batch_size=4),
         num_batch_size_buckets=2)
     # `view` should be repeatedly iterable.
     for _ in range(2):
       batches = list(view)
       self.assertLen(batches, 2)
       npt.assert_equal(
           batches[0], {
               'a': [0, 2, 4, 6],
               'b': [[0, 1], [2, 3], [4, 5], [6, 7]],
               '__mask__': [True, True, True, True],
           })
       npt.assert_equal(batches[1], {
           'a': [8, 0],
           'b': [[8, 9], [0, 0]],
           '__mask__': [True, False]
       })
Example #22
  def test_slice(self):
    d = client_datasets.ClientDataset(
        {
            'a': np.arange(5),
            'b': np.arange(10).reshape([5, 2])
        }, client_datasets.BatchPreprocessor([lambda x: {
            **x, 'a': 2 * x['a']
        }]))

    with self.subTest('slice [:3]'):
      sliced = d[:3]
      batch = next(iter(sliced.batch(batch_size=3)))
      npt.assert_equal(batch, {'a': [0, 2, 4], 'b': [[0, 1], [2, 3], [4, 5]]})

    with self.subTest('slice [-3:]'):
      sliced = d[-3:]
      batch = next(iter(sliced.batch(batch_size=3)))
      npt.assert_equal(batch, {'a': [4, 6, 8], 'b': [[4, 5], [6, 7], [8, 9]]})
Example #23
 def test_preprocessor(self):
   batches = list(
       client_datasets.padded_batch_client_datasets([
           client_datasets.ClientDataset({'x': np.arange(6)},
                                         client_datasets.BatchPreprocessor(
                                             [lambda x: {
                                                 'x': x['x'] + 1
                                             }]))
       ],
                                                    batch_size=5))
   self.assertLen(batches, 2)
   npt.assert_equal(batches[0], {
       'x': np.arange(5) + 1,
       '__mask__': [True, True, True, True, True]
   })
   npt.assert_equal(batches[1], {
       'x': [6, 0, 0, 0, 0],
       '__mask__': [True, False, False, False, False]
   })
Example #24
 def test_preprocessor(self):
   batches = list(
       client_datasets.buffered_shuffle_batch_client_datasets(
           [
               client_datasets.ClientDataset({'x': np.arange(6)},
                                             client_datasets.BatchPreprocessor(
                                                 [lambda x: {
                                                     'x': x['x'] + 1
                                                 }]))
           ],
           batch_size=5,
           buffer_size=16,
           rng=np.random.RandomState(0)))
   self.assertLen(batches, 2)
   npt.assert_equal(batches[0], {
       'x': [6, 3, 2, 4, 1],
   })
   npt.assert_equal(batches[1], {
       'x': [5],
   })
Example #25
 def test_batch(self):
   d = client_datasets.ClientDataset(
       {
           'a': np.arange(5),
           'b': np.arange(10).reshape([5, 2])
       }, client_datasets.BatchPreprocessor([lambda x: {
           **x, 'a': 2 * x['a']
       }]))
   with self.subTest('keep remainder, kwargs'):
     view = d.batch(batch_size=3)
     # `view` should be repeatedly iterable.
     for _ in range(2):
       batches = list(view)
       self.assertLen(batches, 2)
       npt.assert_equal(batches[0], {
           'a': [0, 2, 4],
           'b': [[0, 1], [2, 3], [4, 5]]
       })
       npt.assert_equal(batches[1], {'a': [6, 8], 'b': [[6, 7], [8, 9]]})
   with self.subTest('drop remainder, hparams'):
     view = d.batch(
         client_datasets.BatchHParams(batch_size=3, drop_remainder=True))
     # `view` should be repeatedly iterable.
     for _ in range(2):
       batches = list(view)
       self.assertLen(batches, 1)
       npt.assert_equal(batches[0], {
           'a': [0, 2, 4],
           'b': [[0, 1], [2, 3], [4, 5]]
       })
   with self.subTest('no op drop remainder, hparams and kwargs'):
     view = d.batch(
         client_datasets.BatchHParams(batch_size=5), drop_remainder=True)
     # `view` should be repeatedly iterable.
     for _ in range(2):
       batches = list(view)
       self.assertLen(batches, 1)
       npt.assert_equal(batches[0], {
           'a': [0, 2, 4, 6, 8],
           'b': [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
       })
Example #26
    def test_agnostic_federated_averaging(self):
        algorithm = agnostic_fed_avg.agnostic_federated_averaging(
            per_example_loss=per_example_loss,
            client_optimizer=optimizers.sgd(learning_rate=1.0),
            server_optimizer=optimizers.sgd(learning_rate=0.1),
            client_batch_hparams=client_datasets.ShuffleRepeatBatchHParams(
                batch_size=3, num_epochs=1, seed=0),
            domain_batch_hparams=client_datasets.PaddedBatchHParams(
                batch_size=3),
            init_domain_weights=[0.1, 0.2, 0.3, 0.4],
            domain_learning_rate=0.01,
            domain_algorithm='eg',
            domain_window_size=2,
            init_domain_window=[1., 2., 3., 4.])

        with self.subTest('init'):
            state = algorithm.init({'w': jnp.array(4.)})
            npt.assert_equal(state.params, {'w': jnp.array(4.)})
            self.assertLen(state.opt_state, 2)
            npt.assert_allclose(state.domain_weights, [0.1, 0.2, 0.3, 0.4])
            npt.assert_allclose(state.domain_window,
                                [[1., 2., 3., 4.], [1., 2., 3., 4.]])

        with self.subTest('apply'):
            clients = [
                (b'cid0',
                 client_datasets.ClientDataset({
                     'x':
                     jnp.array([1., 2., 4., 3., 6., 1.]),
                     'domain_id':
                     jnp.array([1, 0, 0, 0, 2, 2])
                 }), jax.random.PRNGKey(0)),
                (b'cid1',
                 client_datasets.ClientDataset({
                     'x':
                     jnp.array([8., 10., 5.]),
                     'domain_id':
                     jnp.array([1, 3, 1])
                 }), jax.random.PRNGKey(1)),
            ]
            next_state, client_diagnostics = algorithm.apply(state, clients)
            npt.assert_allclose(next_state.params['w'], 3.5555556)
            npt.assert_allclose(
                next_state.domain_weights,
                [0.08702461, 0.18604803, 0.2663479, 0.46057943])
            npt.assert_allclose(next_state.domain_window,
                                [[1., 2., 3., 4.], [3., 3., 2., 1.]])
            npt.assert_allclose(client_diagnostics[b'cid0']['delta_l2_norm'],
                                2.8333335)
            npt.assert_allclose(client_diagnostics[b'cid1']['delta_l2_norm'],
                                7.666667)

        with self.subTest('invalid init_domain_weights'):
            with self.assertRaisesRegex(
                    ValueError,
                    'init_domain_weights must sum to approximately 1.'):
                agnostic_fed_avg.agnostic_federated_averaging(
                    per_example_loss=per_example_loss,
                    client_optimizer=optimizers.sgd(learning_rate=1.0),
                    server_optimizer=optimizers.sgd(learning_rate=1.0),
                    client_batch_hparams=client_datasets.
                    ShuffleRepeatBatchHParams(batch_size=3),
                    domain_batch_hparams=client_datasets.PaddedBatchHParams(
                        batch_size=3),
                    init_domain_weights=[50., 0., 0., 0.],
                    domain_learning_rate=0.5)

        with self.subTest('unequal lengths'):
            with self.assertRaisesRegex(
                    ValueError,
                    'init_domain_weights and init_domain_window must be equal lengths.'
            ):
                agnostic_fed_avg.agnostic_federated_averaging(
                    per_example_loss=per_example_loss,
                    client_optimizer=optimizers.sgd(learning_rate=1.0),
                    server_optimizer=optimizers.sgd(learning_rate=1.0),
                    client_batch_hparams=client_datasets.
                    ShuffleRepeatBatchHParams(batch_size=3),
                    domain_batch_hparams=client_datasets.PaddedBatchHParams(
                        batch_size=3),
                    init_domain_weights=[0.1, 0.2, 0.3, 0.4],
                    domain_learning_rate=0.5,
                    init_domain_window=[1, 2])
Example #27
    def test_hyp_cluster(self):
        functions_called = set()

        def per_example_loss(params, batch, rng):
            self.assertIsNotNone(rng)
            functions_called.add('per_example_loss')
            return jnp.square(params - batch['x'])

        def regularizer(params):
            del params
            functions_called.add('regularizer')
            return 0

        client_optimizer = optimizers.sgd(0.5)
        server_optimizer = optimizers.sgd(0.25)

        maximization_batch_hparams = client_datasets.PaddedBatchHParams(
            batch_size=2)
        expectation_batch_hparams = client_datasets.ShuffleRepeatBatchHParams(
            batch_size=1, num_epochs=5)

        algorithm = hyp_cluster.hyp_cluster(
            per_example_loss=per_example_loss,
            client_optimizer=client_optimizer,
            server_optimizer=server_optimizer,
            maximization_batch_hparams=maximization_batch_hparams,
            expectation_batch_hparams=expectation_batch_hparams,
            regularizer=regularizer)

        init_state = algorithm.init([jnp.array(1.), jnp.array(-1.)])
        # Nothing happens with empty data.
        no_op_state, diagnostics = algorithm.apply(init_state, clients=[])
        npt.assert_array_equal(init_state.cluster_params,
                               no_op_state.cluster_params)
        self.assertEmpty(diagnostics)
        # Some actual training. PRNGKeys are not actually used.
        clients = [
            (b'0', client_datasets.ClientDataset({'x': np.array([1.1])}),
             jax.random.PRNGKey(0)),
            (b'1', client_datasets.ClientDataset({'x': np.array([0.9, 0.9])}),
             jax.random.PRNGKey(1)),
            (b'2', client_datasets.ClientDataset({'x': np.array([-1.1])}),
             jax.random.PRNGKey(2)),
            (b'3',
             client_datasets.ClientDataset({'x': np.array([-0.9, -0.9,
                                                           -0.9])}),
             jax.random.PRNGKey(3)),
        ]
        next_state, diagnostics = algorithm.apply(init_state, clients)
        npt.assert_equal(
            diagnostics, {
                b'0': {
                    'cluster_id': 0
                },
                b'1': {
                    'cluster_id': 0
                },
                b'2': {
                    'cluster_id': 1
                },
                b'3': {
                    'cluster_id': 1
                },
            })
        cluster_params = next_state.cluster_params
        self.assertIsInstance(cluster_params, list)
        self.assertLen(cluster_params, 2)
        npt.assert_allclose(cluster_params[0], [1. - 0.25 * 0.1 / 3])
        npt.assert_allclose(cluster_params[1], [-1. + 0.25 * 0.2 / 4])
        self.assertCountEqual(functions_called,
                              ['per_example_loss', 'regularizer'])
Example #28
    def test_hyp_cluster_evaluator(self):
        functions_called = set()

        def apply_for_eval(params, batch):
            functions_called.add('apply_for_eval')
            score = params * batch['x']
            return jnp.stack([-score, score], axis=-1)

        def apply_for_train(params, batch, rng):
            functions_called.add('apply_for_train')
            self.assertIsNotNone(rng)
            return params * batch['x']

        def train_loss(batch, out):
            functions_called.add('train_loss')
            return jnp.abs(batch['y'] * 2 - 1 - out)

        def regularizer(params):
            # Just to check regularizer is called.
            del params
            functions_called.add('regularizer')
            return 0

        evaluator = hyp_cluster.HypClusterEvaluator(
            models.Model(init=None,
                         apply_for_eval=apply_for_eval,
                         apply_for_train=apply_for_train,
                         train_loss=train_loss,
                         eval_metrics={'accuracy': metrics.Accuracy()}),
            regularizer)

        cluster_params = [jnp.array(1.), jnp.array(-1.)]
        train_clients = [
            # Evaluated using cluster 0.
            (b'0',
             client_datasets.ClientDataset({
                 'x': np.array([3., 2., 1.]),
                 'y': np.array([1, 1, 0])
             }), jax.random.PRNGKey(0)),
            # Evaluated using cluster 1.
            (b'1',
             client_datasets.ClientDataset({
                 'x':
                 np.array([0.9, -0.9, 0.8, -0.8, -0.3]),
                 'y':
                 np.array([0, 1, 0, 1, 0])
             }), jax.random.PRNGKey(1)),
        ]
        # Test clients are generated from train_clients by swapping client ids and
        # then flipping labels.
        test_clients = [
            # Evaluated using cluster 0.
            (b'0',
             client_datasets.ClientDataset({
                 'x':
                 np.array([0.9, -0.9, 0.8, -0.8, -0.3]),
                 'y':
                 np.array([1, 0, 1, 0, 1])
             })),
            # Evaluated using cluster 1.
            (b'1',
             client_datasets.ClientDataset({
                 'x': np.array([3., 2., 1.]),
                 'y': np.array([0, 0, 1])
             })),
        ]
        for batch_size in [1, 2, 4]:
            with self.subTest(f'batch_size = {batch_size}'):
                batch_hparams = client_datasets.PaddedBatchHParams(
                    batch_size=batch_size)
                metric_values = dict(
                    evaluator.evaluate_clients(cluster_params=cluster_params,
                                               train_clients=train_clients,
                                               test_clients=test_clients,
                                               batch_hparams=batch_hparams))
                self.assertCountEqual(metric_values, [b'0', b'1'])
                self.assertCountEqual(metric_values[b'0'], ['accuracy'])
                npt.assert_allclose(metric_values[b'0']['accuracy'], 4 / 5)
                self.assertCountEqual(metric_values[b'1'], ['accuracy'])
                npt.assert_allclose(metric_values[b'1']['accuracy'], 2 / 3)
        self.assertCountEqual(
            functions_called,
            ['apply_for_train', 'train_loss', 'apply_for_eval', 'regularizer'])
Example #29
    def test_maximization_step(self):
        # L1 distance from centers.
        def per_example_loss(params, batch, rng):
            # Randomly flip the center to test rng behavior.
            sign = jax.random.bernoulli(rng) * 2 - 1
            return jnp.abs(params * sign - batch['x'])

        def regularizer(params):
            return 0.01 * jnp.sum(jnp.abs(params))

        evaluator = models.AverageLossEvaluator(per_example_loss, regularizer)
        cluster_params = [jnp.array(0.), jnp.array(-1.), jnp.array(2.)]
        # Batch size is chosen so that we run 1 or 2 batches.
        batch_hparams = client_datasets.PaddedBatchHParams(batch_size=2)
        # Special seeds:
        # - No flip in first 2 steps for all 3 clusters: 0;
        # - Flip all in first 2 steps for all 3 clusters: 16;
        # - No flip then flip all for all 3 clusters: 68;
        # - Flips only cluster 1 in first 2 steps: 106.
        clients = [
            # No flip in first 2 steps for all 3 clusters.
            (b'near0', client_datasets.ClientDataset({'x': np.array([0.1])}),
             jax.random.PRNGKey(0)),
            # Flip all in first 2 steps for all 3 clusters.
            (b'near-1',
             client_datasets.ClientDataset({'x': np.array([0.9, 1.1, 1.3])}),
             jax.random.PRNGKey(16)),
            # No flip then flip all for all 3 clusters.
            (b'near2',
             client_datasets.ClientDataset({'x': np.array([1.9, 2.1, -2.1])}),
             jax.random.PRNGKey(68)),
            # Flips only cluster 1 in first 2 steps.
            (b'near1',
             client_datasets.ClientDataset({'x': np.array([0.9, 1.1, 1.3])}),
             jax.random.PRNGKey(106)),
        ]

        cluster_losses = hyp_cluster._cluster_losses(
            evaluator=evaluator,
            cluster_params=cluster_params,
            clients=clients,
            batch_hparams=batch_hparams)
        self.assertCountEqual(cluster_losses,
                              [b'near0', b'near-1', b'near2', b'near1'])
        npt.assert_allclose(cluster_losses[b'near0'],
                            np.array([0.1, 1.1 + 0.01, 1.9 + 0.02]))
        npt.assert_allclose(cluster_losses[b'near-1'],
                            np.array([1.1, 0.5 / 3 + 0.01, 3.1 + 0.02]))
        npt.assert_allclose(cluster_losses[b'near2'],
                            np.array([6.1 / 3, 9.1 / 3 + 0.01, 0.1 + 0.02]),
                            rtol=1e-6)
        npt.assert_allclose(cluster_losses[b'near1'],
                            np.array([1.1, 0.5 / 3 + 0.01, 0.9 + 0.02]))

        client_cluster_ids = hyp_cluster.maximization_step(
            evaluator=evaluator,
            cluster_params=cluster_params,
            clients=clients,
            batch_hparams=batch_hparams)
        self.assertDictEqual(client_cluster_ids, {
            b'near0': 0,
            b'near-1': 1,
            b'near2': 2,
            b'near1': 1
        })
Example #30
    def test_kmeans_init(self):
        functions_called = set()

        def init(rng):
            functions_called.add('init')
            return jax.random.uniform(rng)

        def apply_for_train(params, batch, rng):
            functions_called.add('apply_for_train')
            self.assertIsNotNone(rng)
            return params - batch['x']

        def train_loss(batch, out):
            functions_called.add('train_loss')
            return jnp.square(out) + batch['bias']

        def regularizer(params):
            del params
            functions_called.add('regularizer')
            return 0

        initializer = hyp_cluster.ModelKMeansInitializer(
            models.Model(init=init,
                         apply_for_train=apply_for_train,
                         apply_for_eval=None,
                         train_loss=train_loss,
                         eval_metrics={}), optimizers.sgd(0.5), regularizer)
        # Each client has 1 example, so it's very easy to reach minimal loss, at
        # which point the loss entirely depends on bias.
        clients = [
            (b'0',
             client_datasets.ClientDataset({
                 'x': np.array([1.01]),
                 'bias': np.array([-2.])
             }), jax.random.PRNGKey(1)),
            (b'1',
             client_datasets.ClientDataset({
                 'x': np.array([3.02]),
                 'bias': np.array([-1.])
             }), jax.random.PRNGKey(2)),
            (b'2',
             client_datasets.ClientDataset({
                 'x': np.array([3.03]),
                 'bias': np.array([1.])
             }), jax.random.PRNGKey(3)),
            (b'3',
             client_datasets.ClientDataset({
                 'x': np.array([1.04]),
                 'bias': np.array([2.])
             }), jax.random.PRNGKey(3)),
        ]
        train_batch_hparams = client_datasets.ShuffleRepeatBatchHParams(
            batch_size=1, num_epochs=5)
        eval_batch_hparams = client_datasets.PaddedBatchHParams(batch_size=2)
        # Using a rng that leads to b'0' being the initial center.
        cluster_params = initializer.cluster_params(
            num_clusters=3,
            rng=jax.random.PRNGKey(0),
            clients=clients,
            train_batch_hparams=train_batch_hparams,
            eval_batch_hparams=eval_batch_hparams)
        self.assertIsInstance(cluster_params, list)
        self.assertLen(cluster_params, 3)
        npt.assert_allclose(cluster_params, [1.01, 3.03, 1.04])
        self.assertCountEqual(
            functions_called,
            ['init', 'apply_for_train', 'train_loss', 'regularizer'])
        # Using a rng that leads to b'2' being the initial center.
        cluster_params = initializer.cluster_params(
            num_clusters=3,
            rng=jax.random.PRNGKey(1),
            clients=clients,
            train_batch_hparams=train_batch_hparams,
            eval_batch_hparams=eval_batch_hparams)
        self.assertIsInstance(cluster_params, list)
        self.assertLen(cluster_params, 3)
        npt.assert_allclose(cluster_params, [3.03, 1.04, 1.04])