コード例 #1
0
ファイル: fed_prox_test.py プロジェクト: google/fedjax
    def test_fed_prox(self):
        """Checks FedProx `init` and one round of `apply` on two toy clients."""
        client_opt = optimizers.sgd(learning_rate=1.0)
        server_opt = optimizers.sgd(learning_rate=1.0)
        batch_hparams = client_datasets.ShuffleRepeatBatchHParams(
            batch_size=2, num_epochs=1, seed=0)
        fedprox = fed_prox.fed_prox(
            per_example_loss,
            client_opt,
            server_opt,
            batch_hparams,
            proximal_weight=0.01,
        )

        with self.subTest('init'):
            state = fedprox.init({'w': jnp.array(4.)})
            npt.assert_array_equal(state.params['w'], 4.)
            self.assertLen(state.opt_state, 2)

        # 'apply' continues from the state produced by the 'init' subtest.
        with self.subTest('apply'):
            first = client_datasets.ClientDataset({'x': jnp.array([2., 4., 6.])})
            second = client_datasets.ClientDataset({'x': jnp.array([8., 10.])})
            round_clients = [
                (b'cid0', first, jax.random.PRNGKey(0)),
                (b'cid1', second, jax.random.PRNGKey(1)),
            ]
            state, diagnostics = fedprox.apply(state, round_clients)
            npt.assert_allclose(state.params['w'], -3.77)
            for cid, expected_norm in ((b'cid0', 6.95), (b'cid1', 9.)):
                npt.assert_allclose(diagnostics[cid]['delta_l2_norm'],
                                    expected_norm)
コード例 #2
0
ファイル: fed_avg_test.py プロジェクト: google/fedjax
    def test_federated_averaging(self):
        """Checks FedAvg `init` and one round of `apply` on two toy clients."""
        client_opt = optimizers.sgd(learning_rate=1.0)
        server_opt = optimizers.sgd(learning_rate=1.0)
        batch_hparams = client_datasets.ShuffleRepeatBatchHParams(
            batch_size=2, num_epochs=1, seed=0)
        fedavg = fed_avg.federated_averaging(
            grad_fn, client_opt, server_opt, batch_hparams)

        with self.subTest('init'):
            state = fedavg.init({'w': jnp.array([0., 2., 4.])})
            npt.assert_array_equal(state.params['w'], [0., 2., 4.])
            self.assertLen(state.opt_state, 2)

        # 'apply' continues from the state produced by the 'init' subtest.
        with self.subTest('apply'):
            client_xs = [(b'cid0', [2., 4., 6.]), (b'cid1', [8., 10.])]
            round_clients = [
                (cid, client_datasets.ClientDataset({'x': jnp.array(xs)}),
                 jax.random.PRNGKey(seed))
                for seed, (cid, xs) in enumerate(client_xs)
            ]
            state, diagnostics = fedavg.apply(state, round_clients)
            npt.assert_allclose(state.params['w'], [0., 1.5655555, 3.131111])
            expected_norms = ((b'cid0', 1.4534444262), (b'cid1', 0.2484521282))
            for cid, expected_norm in expected_norms:
                npt.assert_allclose(diagnostics[cid]['delta_l2_norm'],
                                    expected_norm)
コード例 #3
0
ファイル: mime_test.py プロジェクト: google/fedjax
 def test_create_train_for_each_client(self):
   """Mime per-client training: checks each client's returned 'w' value."""
   sgd = optimizers.sgd(learning_rate=1.0)
   trainer = mime.create_train_for_each_client(grad_fn, sgd)
   cid0_batches = [{'x': jnp.array([6., 4.])}, {'x': jnp.array([2., 2.])}]
   cid1_batches = [{'x': jnp.array([10., 8])}]
   batch_clients = [
       (b'cid0', cid0_batches, jax.random.PRNGKey(0)),
       (b'cid1', cid1_batches, jax.random.PRNGKey(1)),
   ]
   params = {'w': jnp.array(4.)}
   # Server-side inputs shared by every client: params, optimizer state and
   # the control variate (server gradient).
   shared = dict(
       params=params,
       opt_state=sgd.init(params),
       control_variate={'w': jnp.array(6.)},
   )
   outputs = dict(trainer(shared, batch_clients))
   for cid, expected_w in ((b'cid0', 12.), (b'cid1', 6.)):
     npt.assert_allclose(outputs[cid]['w'], expected_w)
コード例 #4
0
ファイル: mime_test.py プロジェクト: google/fedjax
  def test_mime(self):
    """Checks Mime `init` and one round of `apply` on two toy clients."""
    sgd = optimizers.sgd(learning_rate=1.0)
    server_lr = 0.2
    algorithm = mime.mime(
        per_example_loss,
        sgd,
        client_datasets.ShuffleRepeatBatchHParams(
            batch_size=2, num_epochs=1, seed=0),
        client_datasets.PaddedBatchHParams(batch_size=2),
        server_lr,
    )

    with self.subTest('init'):
      state = algorithm.init({'w': jnp.array(4.)})
      npt.assert_equal(state.params, {'w': jnp.array(4.)})
      self.assertLen(state.opt_state, 2)

    # 'apply' continues from the state produced by the 'init' subtest.
    with self.subTest('apply'):
      client_xs = (
          (b'cid0', jnp.array([2., 4., 6.])),
          (b'cid1', jnp.array([8., 10.])),
      )
      clients = [(cid, client_datasets.ClientDataset({'x': x}),
                  jax.random.PRNGKey(seed))
                 for seed, (cid, x) in enumerate(client_xs)]
      state, diagnostics = algorithm.apply(state, clients)
      npt.assert_allclose(state.params['w'], 2.08)
      for cid, expected_norm in ((b'cid0', 12.), (b'cid1', 6.)):
        npt.assert_allclose(diagnostics[cid]['delta_l2_norm'], expected_norm)
コード例 #5
0
 def test_create_train_for_each_client(self):
     """Agnostic FedAvg per-client training: checks each client's 'w' output."""
     domain_count = 4
     shared_input = dict(
         params={'w': jnp.array(4.)},
         alpha=jnp.array([0.1, 0.2, 0.3, 0.4]),
     )
     cid0_batches = [
         {'x': jnp.array([1., 2., 4.]), 'domain_id': jnp.array([1, 0, 0])},
         {'x': jnp.array([3., 6., 1.]), 'domain_id': jnp.array([0, 2, 2])},
     ]
     cid1_batches = [
         {'x': jnp.array([8., 10., 5.]), 'domain_id': jnp.array([1, 3, 1])},
     ]
     # Each client carries its own rng and a per-client 'beta' scalar.
     batch_clients = [
         (b'cid0', cid0_batches,
          {'rng': jax.random.PRNGKey(0), 'beta': jnp.array(0.5)}),
         (b'cid1', cid1_batches,
          {'rng': jax.random.PRNGKey(1), 'beta': jnp.array(0.2)}),
     ]
     train = agnostic_fed_avg.create_train_for_each_client(
         per_example_loss, optimizers.sgd(learning_rate=1.0), domain_count)
     deltas = dict(train(shared_input, batch_clients))
     for cid, expected_w in ((b'cid0', 6.4), (b'cid1', 33.)):
         npt.assert_allclose(deltas[cid]['w'], expected_w)
コード例 #6
0
    def test_client_delta_clip_norm(self):
        """Checks MimeLite's reported clipped delta norms with clip norm 0.5."""
        server_lr = 0.2
        algorithm = mime_lite.mime_lite(
            per_example_loss,
            optimizers.sgd(learning_rate=1.0),
            client_datasets.ShuffleRepeatBatchHParams(
                batch_size=2, num_epochs=1, seed=0),
            client_datasets.PaddedBatchHParams(batch_size=2),
            server_lr,
            client_delta_clip_norm=0.5)

        client_xs = (
            (b'cid0', jnp.array([0.2, 0.4, 0.6])),
            (b'cid1', jnp.array([0.8, 0.1])),
        )
        clients = [
            (cid, client_datasets.ClientDataset({'x': x}),
             jax.random.PRNGKey(seed))
            for seed, (cid, x) in enumerate(client_xs)
        ]
        state = algorithm.init({'w': jnp.array(4.)})
        state, diagnostics = algorithm.apply(state, clients)
        npt.assert_allclose(state.params['w'], 3.904)
        for cid, expected_norm in ((b'cid0', 0.5), (b'cid1', 0.45000005)):
            npt.assert_allclose(diagnostics[cid]['clipped_delta_l2_norm'],
                                expected_norm)
コード例 #7
0
ファイル: optimizers_test.py プロジェクト: alshedivat/fedjax
  def test_ignore_grads_haiku(self):
    """Checks `ignore_grads_haiku` leaves non-trainable leaves untouched.

    Every gradient leaf is 0.5 and the wrapped optimizer is SGD(lr=1.0). Only
    the trainable leaf ('linear_2', 'w') is expected to move (2.0 -> 1.5).
    The leaves ('linear_1', 'w') and ('linear_2', 'b') are listed as
    non-trainable and must keep their original values.
    """
    params = hk.data_structures.to_immutable_dict({
        'linear_1': {
            'w': jnp.array([1., 1., 1.])
        },
        'linear_2': {
            'w': jnp.array([2., 2., 2.]),
            'b': jnp.array([3., 3., 3.])
        }
    })
    grads = jax.tree_util.tree_map(lambda _: 0.5, params)
    ignore_optimizer = optimizers.ignore_grads_haiku(
        optimizer=optimizers.sgd(learning_rate=1.0),
        non_trainable_names=[('linear_1', 'w'), ('linear_2', 'b')])

    opt_state = ignore_optimizer.init(params)
    opt_state, updated_params = ignore_optimizer.apply(grads, opt_state, params)

    expected_params = hk.data_structures.to_immutable_dict({
        'linear_1': {
            'w': jnp.array([1., 1., 1.])
        },
        'linear_2': {
            'w': jnp.array([1.5, 1.5, 1.5]),
            'b': jnp.array([3., 3., 3.])
        }
    })
    # `jax.tree_util.tree_multimap` was deprecated and then removed from JAX.
    # `tree_map` accepts multiple trees and is the drop-in replacement; it
    # applies the array-equality assertion leaf by leaf.
    jax.tree_util.tree_map(npt.assert_array_equal, updated_params,
                           expected_params)
コード例 #8
0
ファイル: fed_prox_test.py プロジェクト: google/fedjax
    def test_create_train_for_each_client(self):
        """FedProx per-client training using an explicit proximal-loss gradient."""
        mu = 0.01

        def prox_objective(params, server_params, batch, rng):
            # Mean over examples of (per-example loss + proximal penalty), where
            # the penalty is 0.5 * mu * ||server_params - params||^2.
            losses = per_example_loss(params, batch, rng)
            diff = jax.tree_util.tree_map(lambda a, b: a - b, server_params,
                                          params)
            penalty = 0.5 * mu * tree_util.tree_l2_squared(diff)
            return jnp.mean(losses + penalty)

        trainer = fed_prox.create_train_for_each_client(
            jax.grad(prox_objective), optimizers.sgd(learning_rate=1.0))
        batched_clients = [
            (b'cid0',
             [{'x': jnp.array([2., 4., 6.])}, {'x': jnp.array([8., 10., 12.])}],
             jax.random.PRNGKey(0)),
            (b'cid1',
             [{'x': jnp.array([1., 3., 5.])}, {'x': jnp.array([7., 9., 11.])}],
             jax.random.PRNGKey(1)),
        ]
        outputs = dict(trainer({'w': jnp.array(4.0)}, batched_clients))
        for cid, expected_w in ((b'cid0', 13.96), (b'cid1', 11.97)):
            npt.assert_allclose(outputs[cid]['w'], expected_w)
コード例 #9
0
 def get(self) -> optimizers.Optimizer:
     """Builds the optimizer selected by the `optimizer` flag.

     The learning rate always comes from the `learning_rate` flag. Any
     optimizer-specific hyperparameter flags are read only for the optimizer
     that is actually selected.

     Raises:
       ValueError: if the `optimizer` flag names an unsupported optimizer.
     """
     flag = self._get_flag
     optimizer_name = flag('optimizer')
     lr = flag('learning_rate')
     # Builders are zero-argument thunks, so only the chosen optimizer's extra
     # flags are ever read. Names are checked in the listed order.
     builders = (
         ('sgd', lambda: optimizers.sgd(lr)),
         ('momentum', lambda: optimizers.sgd(lr, flag('momentum'))),
         ('adam', lambda: optimizers.adam(lr, flag('adam_beta1'),
                                          flag('adam_beta2'),
                                          flag('adam_epsilon'))),
         ('rmsprop', lambda: optimizers.rmsprop(lr, flag('rmsprop_decay'),
                                                flag('rmsprop_epsilon'))),
         ('adagrad', lambda: optimizers.adagrad(
             lr, eps=flag('adagrad_epsilon'))),
     )
     for supported_name, build in builders:
         if optimizer_name == supported_name:
             return build()
     raise ValueError(f'Unsupported optimizer {optimizer_name!r} from '
                      f'--{self._prefix}optimizer.')
コード例 #10
0
ファイル: fed_avg_test.py プロジェクト: google/fedjax
 def test_create_train_for_each_client(self):
     """FedAvg per-client training: checks each client's returned 'w' value."""
     trainer = fed_avg.create_train_for_each_client(
         grad_fn, optimizers.sgd(learning_rate=1.0))
     cid0_batches = [{'x': jnp.array([2., 4., 6.])},
                     {'x': jnp.array([8., 10., 12.])}]
     cid1_batches = [{'x': jnp.array([1., 3., 5.])},
                     {'x': jnp.array([7., 9., 11.])}]
     batched_clients = [
         (b'cid0', cid0_batches, jax.random.PRNGKey(0)),
         (b'cid1', cid1_batches, jax.random.PRNGKey(1)),
     ]
     outputs = dict(trainer({'w': jnp.array(4.0)}, batched_clients))
     for cid, expected_w in ((b'cid0', 0.45555544), (b'cid1', 0.5761316)):
         npt.assert_allclose(outputs[cid]['w'], expected_w)
コード例 #11
0
ファイル: hyp_cluster_test.py プロジェクト: google/fedjax
    def test_expectation_step(self):
        """Checks expectation_step returns one aggregated delta per cluster.

        Clusters that no client is assigned to yield None.
        """
        def squared_error(params, batch, rng):
            self.assertIsNotNone(rng)
            return jnp.square(params - batch['x'])

        trainer = hyp_cluster.ClientDeltaTrainer(models.grad(squared_error),
                                                 optimizers.sgd(0.5))
        batch_hparams = client_datasets.ShuffleRepeatBatchHParams(batch_size=1,
                                                                  num_epochs=5)

        cluster_params = [jnp.array(1.), jnp.array(-1.), jnp.array(3.14)]
        client_cluster_ids = {b'0': 0, b'1': 0, b'2': 1, b'3': 1, b'4': 0}
        # (client id, examples); the RNGs passed alongside are not actually used.
        client_specs = [
            (b'0', [1.1]),
            (b'1', [0.9, 0.9]),
            (b'2', [-1.1]),
            (b'3', [-0.9, -0.9, -0.9]),
            (b'4', [-0.1]),
        ]
        clients = [
            (cid, client_datasets.ClientDataset({'x': np.array(xs)}),
             jax.random.PRNGKey(seed))
            for seed, (cid, xs) in enumerate(client_specs)
        ]
        deltas = hyp_cluster.expectation_step(
            trainer=trainer,
            cluster_params=cluster_params,
            client_cluster_ids=client_cluster_ids,
            clients=clients,
            batch_hparams=batch_hparams)
        self.assertIsInstance(deltas, list)
        self.assertLen(deltas, 3)
        npt.assert_allclose(deltas[0], (-0.1 + 0.1 * 2 + 1.1) / 4)
        npt.assert_allclose(deltas[1], (0.1 - 0.1 * 3) / 4, rtol=1e-6)
        # Cluster 2 has no assigned clients.
        self.assertIsNone(deltas[2])
コード例 #12
0
 def test_create_train_for_each_client(self):
     """MimeLite per-client training: checks each client's returned 'w' value."""
     sgd = optimizers.sgd(learning_rate=1.0)
     trainer = mime_lite.create_train_for_each_client(grad_fn, sgd)
     cid0_batches = [{'x': jnp.array([0.6, 0.4])}, {'x': jnp.array([0.2, 0.2])}]
     cid1_batches = [{'x': jnp.array([0.1, 0.8])}]
     batch_clients = [
         (b'cid0', cid0_batches, jax.random.PRNGKey(0)),
         (b'cid1', cid1_batches, jax.random.PRNGKey(1)),
     ]
     params = {'w': jnp.array(4.)}
     shared = dict(params=params, opt_state=sgd.init(params))
     outputs = dict(trainer(shared, batch_clients))
     for cid, expected_w in ((b'cid0', 0.7), (b'cid1', 0.45000002)):
         npt.assert_allclose(outputs[cid]['w'], expected_w)
コード例 #13
0
    def test_agnostic_federated_averaging(self):
        """Agnostic FedAvg: init, one apply round, and constructor validation."""
        client_hparams = client_datasets.ShuffleRepeatBatchHParams(
            batch_size=3, num_epochs=1, seed=0)
        domain_hparams = client_datasets.PaddedBatchHParams(batch_size=3)
        afl = agnostic_fed_avg.agnostic_federated_averaging(
            per_example_loss=per_example_loss,
            client_optimizer=optimizers.sgd(learning_rate=1.0),
            server_optimizer=optimizers.sgd(learning_rate=0.1),
            client_batch_hparams=client_hparams,
            domain_batch_hparams=domain_hparams,
            init_domain_weights=[0.1, 0.2, 0.3, 0.4],
            domain_learning_rate=0.01,
            domain_algorithm='eg',
            domain_window_size=2,
            init_domain_window=[1., 2., 3., 4.])

        with self.subTest('init'):
            state = afl.init({'w': jnp.array(4.)})
            npt.assert_equal(state.params, {'w': jnp.array(4.)})
            self.assertLen(state.opt_state, 2)
            npt.assert_allclose(state.domain_weights, [0.1, 0.2, 0.3, 0.4])
            # The window is initialized to domain_window_size copies of
            # init_domain_window.
            npt.assert_allclose(state.domain_window,
                                [[1., 2., 3., 4.], [1., 2., 3., 4.]])

        with self.subTest('apply'):
            # (client id, example values, domain id of each example)
            per_client = [
                (b'cid0', [1., 2., 4., 3., 6., 1.], [1, 0, 0, 0, 2, 2]),
                (b'cid1', [8., 10., 5.], [1, 3, 1]),
            ]
            clients = [
                (cid,
                 client_datasets.ClientDataset({
                     'x': jnp.array(xs),
                     'domain_id': jnp.array(domains),
                 }), jax.random.PRNGKey(seed))
                for seed, (cid, xs, domains) in enumerate(per_client)
            ]
            next_state, diagnostics = afl.apply(state, clients)
            npt.assert_allclose(next_state.params['w'], 3.5555556)
            npt.assert_allclose(
                next_state.domain_weights,
                [0.08702461, 0.18604803, 0.2663479, 0.46057943])
            npt.assert_allclose(next_state.domain_window,
                                [[1., 2., 3., 4.], [3., 3., 2., 1.]])
            for cid, expected_norm in ((b'cid0', 2.8333335),
                                       (b'cid1', 7.666667)):
                npt.assert_allclose(diagnostics[cid]['delta_l2_norm'],
                                    expected_norm)

        def build_with(**overrides):
            # Fresh constructor arguments per call; only `overrides` differ.
            return agnostic_fed_avg.agnostic_federated_averaging(
                per_example_loss=per_example_loss,
                client_optimizer=optimizers.sgd(learning_rate=1.0),
                server_optimizer=optimizers.sgd(learning_rate=1.0),
                client_batch_hparams=client_datasets.ShuffleRepeatBatchHParams(
                    batch_size=3),
                domain_batch_hparams=client_datasets.PaddedBatchHParams(
                    batch_size=3),
                domain_learning_rate=0.5,
                **overrides)

        with self.subTest('invalid init_domain_weights'):
            with self.assertRaisesRegex(
                    ValueError,
                    'init_domain_weights must sum to approximately 1.'):
                build_with(init_domain_weights=[50., 0., 0., 0.])

        with self.subTest('unequal lengths'):
            with self.assertRaisesRegex(
                    ValueError,
                    'init_domain_weights and init_domain_window must be equal lengths.'
            ):
                build_with(init_domain_weights=[0.1, 0.2, 0.3, 0.4],
                           init_domain_window=[1, 2])
コード例 #14
0
ファイル: hyp_cluster_test.py プロジェクト: google/fedjax
    def test_hyp_cluster(self):
        """HypCluster: no-op on empty input, then one round of real training.

        Also checks that both the loss and the regularizer were invoked.
        """
        called = set()

        def squared_error(params, batch, rng):
            self.assertIsNotNone(rng)
            called.add('per_example_loss')
            return jnp.square(params - batch['x'])

        def zero_regularizer(params):
            del params
            called.add('regularizer')
            return 0

        algorithm = hyp_cluster.hyp_cluster(
            per_example_loss=squared_error,
            client_optimizer=optimizers.sgd(0.5),
            server_optimizer=optimizers.sgd(0.25),
            maximization_batch_hparams=client_datasets.PaddedBatchHParams(
                batch_size=2),
            expectation_batch_hparams=client_datasets.ShuffleRepeatBatchHParams(
                batch_size=1, num_epochs=5),
            regularizer=zero_regularizer)

        init_state = algorithm.init([jnp.array(1.), jnp.array(-1.)])

        # A round with no clients must leave the cluster parameters unchanged.
        unchanged_state, empty_diagnostics = algorithm.apply(init_state,
                                                             clients=[])
        npt.assert_array_equal(init_state.cluster_params,
                               unchanged_state.cluster_params)
        self.assertEmpty(empty_diagnostics)

        # One real round. The PRNG keys are passed but not actually used.
        client_xs = [
            (b'0', [1.1]),
            (b'1', [0.9, 0.9]),
            (b'2', [-1.1]),
            (b'3', [-0.9, -0.9, -0.9]),
        ]
        clients = [
            (cid, client_datasets.ClientDataset({'x': np.array(xs)}),
             jax.random.PRNGKey(seed))
            for seed, (cid, xs) in enumerate(client_xs)
        ]
        next_state, diagnostics = algorithm.apply(init_state, clients)
        expected_assignment = {b'0': 0, b'1': 0, b'2': 1, b'3': 1}
        npt.assert_equal(
            diagnostics,
            {cid: {'cluster_id': k} for cid, k in expected_assignment.items()})
        new_params = next_state.cluster_params
        self.assertIsInstance(new_params, list)
        self.assertLen(new_params, 2)
        npt.assert_allclose(new_params[0], [1. - 0.25 * 0.1 / 3])
        npt.assert_allclose(new_params[1], [-1. + 0.25 * 0.2 / 4])
        self.assertCountEqual(called, ['per_example_loss', 'regularizer'])
コード例 #15
0
ファイル: hyp_cluster_test.py プロジェクト: google/fedjax
class ClientTrainerTest(absltest.TestCase):
    """Tests hyp_cluster.ClientDeltaTrainer and hyp_cluster.ClientParamsTrainer."""

    # Plain SGD with learning rate 1, shared by every test in this class.
    OPTIMIZER = optimizers.sgd(1)

    def test_train_global_params(self):
        """ClientDeltaTrainer.train_global_params: every client starts from `params`.

        With SGD(lr=1) each step subtracts the gradient. The expected delta is
        therefore the sum of the per-step gradients, i.e. initial minus final
        params.
        """
        def grad(params, batch, rng):
            # Toy gradient: 0.5 * params + mean(batch['x']) + uniform(rng) noise.
            return 0.5 * params + jnp.mean(batch['x']) + jax.random.uniform(
                rng, [])

        # NOTE(review): these expected noise draws assume the trainer splits the
        # client rng once per step, using index [1] for the step and index [0]
        # as the carried rng. Confirm against hyp_cluster.
        rng_0 = jax.random.PRNGKey(0)
        rng_uniform_00 = jax.random.uniform(jax.random.split(rng_0)[1], [])
        rng_uniform_01 = jax.random.uniform(
            jax.random.split(jax.random.split(rng_0)[0])[1], [])
        rng_1 = jax.random.PRNGKey(1)
        rng_uniform_10 = jax.random.uniform(jax.random.split(rng_1)[1], [])

        params = jnp.array(1.)
        # b'0000' has two batches (means 0.125 and 0.5); b'1001' has one (mean 1).
        clients = [(b'0000', [{
            'x': jnp.array([0.125])
        }, {
            'x': jnp.array([0.25, 0.75])
        }], rng_0), (b'1001', [{
            'x': jnp.array([0, 1, 2])
        }], rng_1)]
        client_delta_params = dict(
            hyp_cluster.ClientDeltaTrainer(grad,
                                           self.OPTIMIZER).train_global_params(
                                               params, clients))
        self.assertCountEqual(client_delta_params, [b'0000', b'1001'])
        # Step 1 gradient, then step 2 gradient evaluated at 1 - step1 =
        # 0.375 - rng_uniform_00.
        npt.assert_allclose(
            client_delta_params[b'0000'], (0.5 * 1 + 0.125 + rng_uniform_00) +
            (0.5 * (0.375 - rng_uniform_00) + 0.5 + rng_uniform_01))
        npt.assert_allclose(client_delta_params[b'1001'],
                            (0.5 * 1 + 1) + rng_uniform_10)

    def test_train_per_client_params(self):
        """ClientDeltaTrainer.train_per_client_params: each client brings its own params.

        The initial params are the 4th element of each client tuple (1.0 for
        b'0000', 0.5 for b'1001').
        """
        def grad(params, batch, rng):
            # Same toy gradient as in test_train_global_params.
            return 0.5 * params + jnp.mean(batch['x']) + jax.random.uniform(
                rng, [])

        # NOTE(review): same rng-split assumption as in test_train_global_params.
        rng_0 = jax.random.PRNGKey(0)
        rng_uniform_00 = jax.random.uniform(jax.random.split(rng_0)[1], [])
        rng_uniform_01 = jax.random.uniform(
            jax.random.split(jax.random.split(rng_0)[0])[1], [])
        rng_1 = jax.random.PRNGKey(1)
        rng_uniform_10 = jax.random.uniform(jax.random.split(rng_1)[1], [])

        clients = [
            (b'0000', [{
                'x': jnp.array([0.125])
            }, {
                'x': jnp.array([0.25, 0.75])
            }], rng_0, jnp.array(1.)),
            (b'1001', [{
                'x': jnp.array([0, 1, 2])
            }], rng_1, jnp.array(0.5)),
        ]
        client_delta_params = dict(
            hyp_cluster.ClientDeltaTrainer(
                grad, self.OPTIMIZER).train_per_client_params(clients))
        self.assertCountEqual(client_delta_params, [b'0000', b'1001'])
        npt.assert_allclose(
            client_delta_params[b'0000'], (0.5 * 1 + 0.125 + rng_uniform_00) +
            (0.5 * (0.375 - rng_uniform_00) + 0.5 + rng_uniform_01))
        # b'1001' starts from 0.5, hence 0.5 * 0.5 in its single-step gradient.
        npt.assert_allclose(client_delta_params[b'1001'],
                            (0.5 * 0.5 + 1) + rng_uniform_10)

    def test_return_params(self):
        """ClientParamsTrainer returns final params (not deltas) per client."""
        def grad(params, batch, _):
            # Deterministic toy gradient; the rng argument is ignored.
            return 0.5 * params + jnp.mean(batch['x'])

        params = jnp.array(1.)
        clients = [(b'0000', [{
            'x': jnp.array([0.125])
        }, {
            'x': jnp.array([0.25, 0.75])
        }], jax.random.PRNGKey(0)),
                   (b'1001', [{
                       'x': jnp.array([0, 1, 2])
                   }], jax.random.PRNGKey(1))]
        # Despite the variable name, these values are trained params: the
        # expected values below are 1 minus the summed step gradients.
        client_delta_params = dict(
            hyp_cluster.ClientParamsTrainer(
                grad, self.OPTIMIZER).train_global_params(params, clients))
        npt.assert_equal(
            jax.device_get(client_delta_params), {
                b'0000': 1 - ((0.5 * 1 + 0.125) + (0.5 * 0.375 + 0.5)),
                b'1001': 1 - (0.5 * 1 + 1)
            })
コード例 #16
0
ファイル: hyp_cluster_test.py プロジェクト: google/fedjax
    def test_kmeans_init(self):
        """Checks ModelKMeansInitializer.cluster_params with two different rngs.

        Four single-example clients are used. The expected three cluster params
        depend on which client the `rng` argument selects as the initial center
        (see the inline comments). The test also checks that init,
        apply_for_train, train_loss and regularizer were all invoked.
        """
        # Records which model callbacks the initializer actually invoked.
        functions_called = set()

        def init(rng):
            functions_called.add('init')
            return jax.random.uniform(rng)

        def apply_for_train(params, batch, rng):
            functions_called.add('apply_for_train')
            self.assertIsNotNone(rng)
            return params - batch['x']

        def train_loss(batch, out):
            # Squared residual plus a per-example constant bias.
            functions_called.add('train_loss')
            return jnp.square(out) + batch['bias']

        def regularizer(params):
            del params
            functions_called.add('regularizer')
            return 0

        initializer = hyp_cluster.ModelKMeansInitializer(
            models.Model(init=init,
                         apply_for_train=apply_for_train,
                         apply_for_eval=None,
                         train_loss=train_loss,
                         eval_metrics={}), optimizers.sgd(0.5), regularizer)
        # Each client has 1 example, so it's very easy to reach minimal loss, at
        # which point the loss entirely depends on bias.
        clients = [
            (b'0',
             client_datasets.ClientDataset({
                 'x': np.array([1.01]),
                 'bias': np.array([-2.])
             }), jax.random.PRNGKey(1)),
            (b'1',
             client_datasets.ClientDataset({
                 'x': np.array([3.02]),
                 'bias': np.array([-1.])
             }), jax.random.PRNGKey(2)),
            (b'2',
             client_datasets.ClientDataset({
                 'x': np.array([3.03]),
                 'bias': np.array([1.])
             }), jax.random.PRNGKey(3)),
            # NOTE(review): b'3' reuses PRNGKey(3), the same key as b'2'. This
            # looks like a copy-paste slip and is harmless only if per-client
            # rngs do not affect the result. Confirm before changing, since the
            # expected values below were computed with it.
            (b'3',
             client_datasets.ClientDataset({
                 'x': np.array([1.04]),
                 'bias': np.array([2.])
             }), jax.random.PRNGKey(3)),
        ]
        train_batch_hparams = client_datasets.ShuffleRepeatBatchHParams(
            batch_size=1, num_epochs=5)
        eval_batch_hparams = client_datasets.PaddedBatchHParams(batch_size=2)
        # Using a rng that leads to b'0' being the initial center.
        cluster_params = initializer.cluster_params(
            num_clusters=3,
            rng=jax.random.PRNGKey(0),
            clients=clients,
            train_batch_hparams=train_batch_hparams,
            eval_batch_hparams=eval_batch_hparams)
        self.assertIsInstance(cluster_params, list)
        self.assertLen(cluster_params, 3)
        npt.assert_allclose(cluster_params, [1.01, 3.03, 1.04])
        self.assertCountEqual(
            functions_called,
            ['init', 'apply_for_train', 'train_loss', 'regularizer'])
        # Using a rng that leads to b'2' being the initial center.
        cluster_params = initializer.cluster_params(
            num_clusters=3,
            rng=jax.random.PRNGKey(1),
            clients=clients,
            train_batch_hparams=train_batch_hparams,
            eval_batch_hparams=eval_batch_hparams)
        self.assertIsInstance(cluster_params, list)
        self.assertLen(cluster_params, 3)
        # The same client's params may be selected twice (1.04 appears twice).
        npt.assert_allclose(cluster_params, [3.03, 1.04, 1.04])