Example #1
File: mime_test.py Project: google/fedjax
  def test_mime(self):
    base_optimizer = optimizers.sgd(learning_rate=1.0)
    train_batch_hparams = client_datasets.ShuffleRepeatBatchHParams(
        batch_size=2, num_epochs=1, seed=0)
    grad_batch_hparams = client_datasets.PaddedBatchHParams(batch_size=2)
    server_learning_rate = 0.2
    algorithm = mime.mime(per_example_loss, base_optimizer, train_batch_hparams,
                          grad_batch_hparams, server_learning_rate)

    with self.subTest('init'):
      state = algorithm.init({'w': jnp.array(4.)})
      npt.assert_equal(state.params, {'w': jnp.array(4.)})
      self.assertLen(state.opt_state, 2)

    with self.subTest('apply'):
      clients = [
          (b'cid0',
           client_datasets.ClientDataset({'x': jnp.array([2., 4., 6.])}),
           jax.random.PRNGKey(0)),
          (b'cid1', client_datasets.ClientDataset({'x': jnp.array([8., 10.])}),
           jax.random.PRNGKey(1)),
      ]
      state, client_diagnostics = algorithm.apply(state, clients)
      npt.assert_allclose(state.params['w'], 2.08)
      npt.assert_allclose(client_diagnostics[b'cid0']['delta_l2_norm'], 12.)
      npt.assert_allclose(client_diagnostics[b'cid1']['delta_l2_norm'], 6.)
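Every example on this page constructs client_datasets.ShuffleRepeatBatchHParams to control client-side shuffling, repeating, and batching. As a minimal standalone sketch of what these hparams drive (assuming the fedjax package and the ClientDataset.shuffle_repeat_batch method, with the import path used in the fedjax test suite):

import numpy as np
from fedjax.core import client_datasets

# A tiny single-client dataset of five examples.
dataset = client_datasets.ClientDataset({'x': np.array([1., 2., 3., 4., 5.])})
# Shuffle with a fixed seed, make one pass over the data, batch 2 at a time.
hparams = client_datasets.ShuffleRepeatBatchHParams(
    batch_size=2, num_epochs=1, seed=0)
for batch in dataset.shuffle_repeat_batch(hparams):
  print(batch['x'])  # each batch is a dict of arrays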
Example #2
    def test_federated_averaging(self):
        client_optimizer = optimizers.sgd(learning_rate=1.0)
        server_optimizer = optimizers.sgd(learning_rate=1.0)
        client_batch_hparams = client_datasets.ShuffleRepeatBatchHParams(
            batch_size=2, num_epochs=1, seed=0)
        algorithm = fed_avg.federated_averaging(grad_fn, client_optimizer,
                                                server_optimizer,
                                                client_batch_hparams)

        with self.subTest('init'):
            state = algorithm.init({'w': jnp.array([0., 2., 4.])})
            npt.assert_array_equal(state.params['w'], [0., 2., 4.])
            self.assertLen(state.opt_state, 2)

        with self.subTest('apply'):
            clients = [
                (b'cid0',
                 client_datasets.ClientDataset({'x': jnp.array([2., 4., 6.])}),
                 jax.random.PRNGKey(0)),
                (b'cid1',
                 client_datasets.ClientDataset({'x': jnp.array([8., 10.])}),
                 jax.random.PRNGKey(1)),
            ]
            state, client_diagnostics = algorithm.apply(state, clients)
            npt.assert_allclose(state.params['w'], [0., 1.5655555, 3.131111])
            npt.assert_allclose(client_diagnostics[b'cid0']['delta_l2_norm'],
                                1.4534444262)
            npt.assert_allclose(client_diagnostics[b'cid1']['delta_l2_norm'],
                                0.2484521282)
Example #3
    def test_client_delta_clip_norm(self):
        base_optimizer = optimizers.sgd(learning_rate=1.0)
        train_batch_hparams = client_datasets.ShuffleRepeatBatchHParams(
            batch_size=2, num_epochs=1, seed=0)
        grad_batch_hparams = client_datasets.PaddedBatchHParams(batch_size=2)
        server_learning_rate = 0.2
        algorithm = mime_lite.mime_lite(per_example_loss,
                                        base_optimizer,
                                        train_batch_hparams,
                                        grad_batch_hparams,
                                        server_learning_rate,
                                        client_delta_clip_norm=0.5)

        clients = [
            (b'cid0',
             client_datasets.ClientDataset({'x': jnp.array([0.2, 0.4, 0.6])}),
             jax.random.PRNGKey(0)),
            (b'cid1',
             client_datasets.ClientDataset({'x': jnp.array([0.8, 0.1])}),
             jax.random.PRNGKey(1)),
        ]
        state = algorithm.init({'w': jnp.array(4.)})
        state, client_diagnostics = algorithm.apply(state, clients)
        npt.assert_allclose(state.params['w'], 3.904)
        npt.assert_allclose(
            client_diagnostics[b'cid0']['clipped_delta_l2_norm'], 0.5)
        npt.assert_allclose(
            client_diagnostics[b'cid1']['clipped_delta_l2_norm'], 0.45000005)
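Reading the assertions: b'cid0''s raw client delta exceeds client_delta_clip_norm=0.5 and is clipped to exactly 0.5, while b'cid1''s delta norm (roughly 0.45) is already below the threshold and passes through effectively unchanged.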
Example #4
    def test_fed_prox(self):
        client_optimizer = optimizers.sgd(learning_rate=1.0)
        server_optimizer = optimizers.sgd(learning_rate=1.0)
        client_batch_hparams = client_datasets.ShuffleRepeatBatchHParams(
            batch_size=2, num_epochs=1, seed=0)
        algorithm = fed_prox.fed_prox(per_example_loss,
                                      client_optimizer,
                                      server_optimizer,
                                      client_batch_hparams,
                                      proximal_weight=0.01)

        with self.subTest('init'):
            state = algorithm.init({'w': jnp.array(4.)})
            npt.assert_array_equal(state.params['w'], 4.)
            self.assertLen(state.opt_state, 2)

        with self.subTest('apply'):
            clients = [
                (b'cid0',
                 client_datasets.ClientDataset({'x': jnp.array([2., 4., 6.])}),
                 jax.random.PRNGKey(0)),
                (b'cid1',
                 client_datasets.ClientDataset({'x': jnp.array([8., 10.])}),
                 jax.random.PRNGKey(1)),
            ]
            state, client_diagnostics = algorithm.apply(state, clients)
            npt.assert_allclose(state.params['w'], -3.77)
            npt.assert_allclose(client_diagnostics[b'cid0']['delta_l2_norm'],
                                6.95)
            npt.assert_allclose(client_diagnostics[b'cid1']['delta_l2_norm'],
                                9.)
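The difference from plain federated averaging in Example #2 is the client objective: following the FedProx formulation, a proximal term proportional to proximal_weight times the squared distance between the client parameters and the current server parameters is added to the per-example loss, discouraging client drift from the server model.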
Example #5
    def __init__(self,
                 name: Optional[str] = None,
                 default_batch_size: int = 128):
        super().__init__(name)
        defaults = client_datasets.ShuffleRepeatBatchHParams(batch_size=-1)
        # TODO(wuke): Support other fields.
        self._integer('batch_size', default_batch_size, 'Batch size')
        self._integer('num_epochs', defaults.num_epochs, 'Number of epochs')
        self._integer('num_steps', defaults.num_steps, 'Number of steps')
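The batch_size=-1 construction above is just a device for reading the dataclass defaults of the remaining fields. A sketch of the idea, assuming ShuffleRepeatBatchHParams defaults num_epochs to 1 and num_steps to None as in fedjax:

from fedjax.core import client_datasets

# batch_size is required, so any placeholder value works for reading defaults.
defaults = client_datasets.ShuffleRepeatBatchHParams(batch_size=-1)
print(defaults.num_epochs)  # expected: 1
print(defaults.num_steps)   # expected: None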
Example #6
    def test_get(self):
        with self.subTest('default'):
            self.assertEqual(
                self.SHUFFLE_BATCH_REPEAT.get(),
                client_datasets.ShuffleRepeatBatchHParams(batch_size=128))
            self.assertEqual(
                self.PADDED_BATCH.get(),
                client_datasets.PaddedBatchHParams(batch_size=128))
            self.assertEqual(self.BATCH.get(),
                             client_datasets.BatchHParams(batch_size=128))
        with self.subTest('custom'):
            with flagsaver.flagsaver(shuffle_batch_repeat_batch_size=12,
                                     shuffle_batch_repeat_num_epochs=21,
                                     padded_batch_size=34,
                                     batch_size=56):
                self.assertEqual(
                    self.SHUFFLE_BATCH_REPEAT.get(),
                    client_datasets.ShuffleRepeatBatchHParams(batch_size=12,
                                                              num_epochs=21))
                self.assertEqual(
                    self.PADDED_BATCH.get(),
                    client_datasets.PaddedBatchHParams(batch_size=34))
                self.assertEqual(self.BATCH.get(),
                                 client_datasets.BatchHParams(batch_size=56))
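flagsaver.flagsaver, from absl.testing, temporarily overrides the named flags for the duration of the with block and restores the previous values on exit, so the 'custom' subtest cannot leak flag state into other tests.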
Example #7
    def test_expectation_step(self):
        def per_example_loss(params, batch, rng):
            self.assertIsNotNone(rng)
            return jnp.square(params - batch['x'])

        trainer = hyp_cluster.ClientDeltaTrainer(models.grad(per_example_loss),
                                                 optimizers.sgd(0.5))
        batch_hparams = client_datasets.ShuffleRepeatBatchHParams(batch_size=1,
                                                                  num_epochs=5)

        cluster_params = [jnp.array(1.), jnp.array(-1.), jnp.array(3.14)]
        client_cluster_ids = {b'0': 0, b'1': 0, b'2': 1, b'3': 1, b'4': 0}
        # RNGs are not actually used.
        clients = [
            (b'0', client_datasets.ClientDataset({'x': np.array([1.1])}),
             jax.random.PRNGKey(0)),
            (b'1', client_datasets.ClientDataset({'x': np.array([0.9, 0.9])}),
             jax.random.PRNGKey(1)),
            (b'2', client_datasets.ClientDataset({'x': np.array([-1.1])}),
             jax.random.PRNGKey(2)),
            (b'3',
             client_datasets.ClientDataset({'x': np.array([-0.9, -0.9,
                                                           -0.9])}),
             jax.random.PRNGKey(3)),
            (b'4', client_datasets.ClientDataset({'x': np.array([-0.1])}),
             jax.random.PRNGKey(4)),
        ]
        cluster_delta_params = hyp_cluster.expectation_step(
            trainer=trainer,
            cluster_params=cluster_params,
            client_cluster_ids=client_cluster_ids,
            clients=clients,
            batch_hparams=batch_hparams)
        self.assertIsInstance(cluster_delta_params, list)
        self.assertLen(cluster_delta_params, 3)
        npt.assert_allclose(cluster_delta_params[0],
                            (-0.1 + 0.1 * 2 + 1.1) / 4)
        npt.assert_allclose(cluster_delta_params[1], (0.1 - 0.1 * 3) / 4,
                            rtol=1e-6)
        self.assertIsNone(cluster_delta_params[2])
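The expected values can be worked out by hand. The loss is (params - x)^2, so a single SGD step at learning rate 0.5 moves w to w - 0.5 * 2 * (w - x) = x: each client's trained parameter lands exactly on its example value, and its delta is the cluster's initial parameter minus x. Cluster 0 (clients b'0', b'1', b'4'; initial parameter 1.0; 4 examples in total) averages those deltas weighted by example count, giving (-0.1 * 1 + 0.1 * 2 + 1.1 * 1) / 4. Cluster 1 (b'2', b'3'; initial parameter -1.0) gives (0.1 * 1 - 0.1 * 3) / 4, and cluster 2 received no clients, hence None.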
Example #8
    def test_agnostic_federated_averaging(self):
        algorithm = agnostic_fed_avg.agnostic_federated_averaging(
            per_example_loss=per_example_loss,
            client_optimizer=optimizers.sgd(learning_rate=1.0),
            server_optimizer=optimizers.sgd(learning_rate=0.1),
            client_batch_hparams=client_datasets.ShuffleRepeatBatchHParams(
                batch_size=3, num_epochs=1, seed=0),
            domain_batch_hparams=client_datasets.PaddedBatchHParams(
                batch_size=3),
            init_domain_weights=[0.1, 0.2, 0.3, 0.4],
            domain_learning_rate=0.01,
            domain_algorithm='eg',
            domain_window_size=2,
            init_domain_window=[1., 2., 3., 4.])

        with self.subTest('init'):
            state = algorithm.init({'w': jnp.array(4.)})
            npt.assert_equal(state.params, {'w': jnp.array(4.)})
            self.assertLen(state.opt_state, 2)
            npt.assert_allclose(state.domain_weights, [0.1, 0.2, 0.3, 0.4])
            npt.assert_allclose(state.domain_window,
                                [[1., 2., 3., 4.], [1., 2., 3., 4.]])

        with self.subTest('apply'):
            clients = [
                (b'cid0',
                 client_datasets.ClientDataset({
                     'x':
                     jnp.array([1., 2., 4., 3., 6., 1.]),
                     'domain_id':
                     jnp.array([1, 0, 0, 0, 2, 2])
                 }), jax.random.PRNGKey(0)),
                (b'cid1',
                 client_datasets.ClientDataset({
                     'x':
                     jnp.array([8., 10., 5.]),
                     'domain_id':
                     jnp.array([1, 3, 1])
                 }), jax.random.PRNGKey(1)),
            ]
            next_state, client_diagnostics = algorithm.apply(state, clients)
            npt.assert_allclose(next_state.params['w'], 3.5555556)
            npt.assert_allclose(
                next_state.domain_weights,
                [0.08702461, 0.18604803, 0.2663479, 0.46057943])
            npt.assert_allclose(next_state.domain_window,
                                [[1., 2., 3., 4.], [3., 3., 2., 1.]])
            npt.assert_allclose(client_diagnostics[b'cid0']['delta_l2_norm'],
                                2.8333335)
            npt.assert_allclose(client_diagnostics[b'cid1']['delta_l2_norm'],
                                7.666667)

        with self.subTest('invalid init_domain_weights'):
            with self.assertRaisesRegex(
                    ValueError,
                    'init_domain_weights must sum to approximately 1.'):
                agnostic_fed_avg.agnostic_federated_averaging(
                    per_example_loss=per_example_loss,
                    client_optimizer=optimizers.sgd(learning_rate=1.0),
                    server_optimizer=optimizers.sgd(learning_rate=1.0),
                    client_batch_hparams=client_datasets.
                    ShuffleRepeatBatchHParams(batch_size=3),
                    domain_batch_hparams=client_datasets.PaddedBatchHParams(
                        batch_size=3),
                    init_domain_weights=[50., 0., 0., 0.],
                    domain_learning_rate=0.5)

        with self.subTest('unequal lengths'):
            with self.assertRaisesRegex(
                    ValueError,
                    'init_domain_weights and init_domain_window must be equal lengths.'
            ):
                agnostic_fed_avg.agnostic_federated_averaging(
                    per_example_loss=per_example_loss,
                    client_optimizer=optimizers.sgd(learning_rate=1.0),
                    server_optimizer=optimizers.sgd(learning_rate=1.0),
                    client_batch_hparams=client_datasets.
                    ShuffleRepeatBatchHParams(batch_size=3),
                    domain_batch_hparams=client_datasets.PaddedBatchHParams(
                        batch_size=3),
                    init_domain_weights=[0.1, 0.2, 0.3, 0.4],
                    domain_learning_rate=0.5,
                    init_domain_window=[1, 2])
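The domain_window assertion can be read off the client data: cid0 contributes domain ids [1, 0, 0, 0, 2, 2] and cid1 contributes [1, 3, 1], so this round sees 3, 3, 2 and 1 examples in domains 0 through 3. Those counts become the newest window row [3., 3., 2., 1.], while the oldest of the domain_window_size=2 rows is dropped.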
Example #9
    def test_hyp_cluster(self):
        functions_called = set()

        def per_example_loss(params, batch, rng):
            self.assertIsNotNone(rng)
            functions_called.add('per_example_loss')
            return jnp.square(params - batch['x'])

        def regularizer(params):
            del params
            functions_called.add('regularizer')
            return 0

        client_optimizer = optimizers.sgd(0.5)
        server_optimizer = optimizers.sgd(0.25)

        maximization_batch_hparams = client_datasets.PaddedBatchHParams(
            batch_size=2)
        expectation_batch_hparams = client_datasets.ShuffleRepeatBatchHParams(
            batch_size=1, num_epochs=5)

        algorithm = hyp_cluster.hyp_cluster(
            per_example_loss=per_example_loss,
            client_optimizer=client_optimizer,
            server_optimizer=server_optimizer,
            maximization_batch_hparams=maximization_batch_hparams,
            expectation_batch_hparams=expectation_batch_hparams,
            regularizer=regularizer)

        init_state = algorithm.init([jnp.array(1.), jnp.array(-1.)])
        # Nothing happens with empty data.
        no_op_state, diagnostics = algorithm.apply(init_state, clients=[])
        npt.assert_array_equal(init_state.cluster_params,
                               no_op_state.cluster_params)
        self.assertEmpty(diagnostics)
        # Some actual training. PRNGKeys are not actually used.
        clients = [
            (b'0', client_datasets.ClientDataset({'x': np.array([1.1])}),
             jax.random.PRNGKey(0)),
            (b'1', client_datasets.ClientDataset({'x': np.array([0.9, 0.9])}),
             jax.random.PRNGKey(1)),
            (b'2', client_datasets.ClientDataset({'x': np.array([-1.1])}),
             jax.random.PRNGKey(2)),
            (b'3',
             client_datasets.ClientDataset({'x': np.array([-0.9, -0.9,
                                                           -0.9])}),
             jax.random.PRNGKey(3)),
        ]
        next_state, diagnostics = algorithm.apply(init_state, clients)
        npt.assert_equal(
            diagnostics, {
                b'0': {
                    'cluster_id': 0
                },
                b'1': {
                    'cluster_id': 0
                },
                b'2': {
                    'cluster_id': 1
                },
                b'3': {
                    'cluster_id': 1
                },
            })
        cluster_params = next_state.cluster_params
        self.assertIsInstance(cluster_params, list)
        self.assertLen(cluster_params, 2)
        npt.assert_allclose(cluster_params[0], [1. - 0.25 * 0.1 / 3])
        npt.assert_allclose(cluster_params[1], [-1. + 0.25 * 0.2 / 4])
        self.assertCountEqual(functions_called,
                              ['per_example_loss', 'regularizer'])
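The cluster parameter assertions follow the same arithmetic as the expectation-step test in Example #7: each client converges to its example value, so cluster 0 (b'0' and b'1', 3 examples) has mean delta (-0.1 + 0.1 * 2) / 3 = 0.1 / 3 and cluster 1 (b'2' and b'3', 4 examples) has mean delta (0.1 - 0.1 * 3) / 4 = -0.2 / 4; one server SGD step at learning rate 0.25 then yields the asserted 1 - 0.25 * 0.1 / 3 and -1 + 0.25 * 0.2 / 4.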
Example #10
    def test_kmeans_init(self):
        functions_called = set()

        def init(rng):
            functions_called.add('init')
            return jax.random.uniform(rng)

        def apply_for_train(params, batch, rng):
            functions_called.add('apply_for_train')
            self.assertIsNotNone(rng)
            return params - batch['x']

        def train_loss(batch, out):
            functions_called.add('train_loss')
            return jnp.square(out) + batch['bias']

        def regularizer(params):
            del params
            functions_called.add('regularizer')
            return 0

        initializer = hyp_cluster.ModelKMeansInitializer(
            models.Model(init=init,
                         apply_for_train=apply_for_train,
                         apply_for_eval=None,
                         train_loss=train_loss,
                         eval_metrics={}), optimizers.sgd(0.5), regularizer)
        # Each client has 1 example, so it's very easy to reach minimal loss, at
        # which point the loss entirely depends on bias.
        clients = [
            (b'0',
             client_datasets.ClientDataset({
                 'x': np.array([1.01]),
                 'bias': np.array([-2.])
             }), jax.random.PRNGKey(1)),
            (b'1',
             client_datasets.ClientDataset({
                 'x': np.array([3.02]),
                 'bias': np.array([-1.])
             }), jax.random.PRNGKey(2)),
            (b'2',
             client_datasets.ClientDataset({
                 'x': np.array([3.03]),
                 'bias': np.array([1.])
             }), jax.random.PRNGKey(3)),
            (b'3',
             client_datasets.ClientDataset({
                 'x': np.array([1.04]),
                 'bias': np.array([2.])
             }), jax.random.PRNGKey(3)),
        ]
        train_batch_hparams = client_datasets.ShuffleRepeatBatchHParams(
            batch_size=1, num_epochs=5)
        eval_batch_hparams = client_datasets.PaddedBatchHParams(batch_size=2)
        # Using a rng that leads to b'0' being the initial center.
        cluster_params = initializer.cluster_params(
            num_clusters=3,
            rng=jax.random.PRNGKey(0),
            clients=clients,
            train_batch_hparams=train_batch_hparams,
            eval_batch_hparams=eval_batch_hparams)
        self.assertIsInstance(cluster_params, list)
        self.assertLen(cluster_params, 3)
        npt.assert_allclose(cluster_params, [1.01, 3.03, 1.04])
        self.assertCountEqual(
            functions_called,
            ['init', 'apply_for_train', 'train_loss', 'regularizer'])
        # Using a rng that leads to b'2' being the initial center.
        cluster_params = initializer.cluster_params(
            num_clusters=3,
            rng=jax.random.PRNGKey(1),
            clients=clients,
            train_batch_hparams=train_batch_hparams,
            eval_batch_hparams=eval_batch_hparams)
        self.assertIsInstance(cluster_params, list)
        self.assertLen(cluster_params, 3)
        npt.assert_allclose(cluster_params, [3.03, 1.04, 1.04])
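Because every client holds a single example, local training drives each client's parameter to its x value, so the candidate centers after the training pass are roughly 1.01, 3.02, 3.03 and 1.04; which of them end up as the num_clusters=3 centers depends only on the rng, which is exactly what the two cluster_params assertions probe.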
Example #11
    def get(self):
        return client_datasets.ShuffleRepeatBatchHParams(
            batch_size=self._get_flag('batch_size'),
            num_epochs=self._get_flag('num_epochs'),
            num_steps=self._get_flag('num_steps'))
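This get() is the read side of the flag definitions in Example #5: it collects the batch_size, num_epochs, and num_steps flags back into a ShuffleRepeatBatchHParams, which is the behavior Example #6 exercises via flagsaver. A hypothetical usage, assuming SHUFFLE_BATCH_REPEAT is an instance of this helper and dataset is a client_datasets.ClientDataset:

hparams = SHUFFLE_BATCH_REPEAT.get()
batches = dataset.shuffle_repeat_batch(hparams)  # iterate over batches to train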