Example #1
    def test_client_delta_clip_norm(self):
        base_optimizer = optimizers.sgd(learning_rate=1.0)
        train_batch_hparams = client_datasets.ShuffleRepeatBatchHParams(
            batch_size=2, num_epochs=1, seed=0)
        grad_batch_hparams = client_datasets.PaddedBatchHParams(batch_size=2)
        server_learning_rate = 0.2
        algorithm = mime_lite.mime_lite(per_example_loss,
                                        base_optimizer,
                                        train_batch_hparams,
                                        grad_batch_hparams,
                                        server_learning_rate,
                                        client_delta_clip_norm=0.5)

        clients = [
            (b'cid0',
             client_datasets.ClientDataset({'x': jnp.array([0.2, 0.4, 0.6])}),
             jax.random.PRNGKey(0)),
            (b'cid1',
             client_datasets.ClientDataset({'x': jnp.array([0.8, 0.1])}),
             jax.random.PRNGKey(1)),
        ]
        state = algorithm.init({'w': jnp.array(4.)})
        state, client_diagnostics = algorithm.apply(state, clients)
        npt.assert_allclose(state.params['w'], 3.904)
        npt.assert_allclose(
            client_diagnostics[b'cid0']['clipped_delta_l2_norm'], 0.5)
        npt.assert_allclose(
            client_diagnostics[b'cid1']['clipped_delta_l2_norm'], 0.45000005)
Example #2
File: mime_test.py  Project: google/fedjax
  def test_mime(self):
    base_optimizer = optimizers.sgd(learning_rate=1.0)
    train_batch_hparams = client_datasets.ShuffleRepeatBatchHParams(
        batch_size=2, num_epochs=1, seed=0)
    grad_batch_hparams = client_datasets.PaddedBatchHParams(batch_size=2)
    server_learning_rate = 0.2
    algorithm = mime.mime(per_example_loss, base_optimizer, train_batch_hparams,
                          grad_batch_hparams, server_learning_rate)

    with self.subTest('init'):
      state = algorithm.init({'w': jnp.array(4.)})
      npt.assert_equal(state.params, {'w': jnp.array(4.)})
      self.assertLen(state.opt_state, 2)

    with self.subTest('apply'):
      clients = [
          (b'cid0',
           client_datasets.ClientDataset({'x': jnp.array([2., 4., 6.])}),
           jax.random.PRNGKey(0)),
          (b'cid1', client_datasets.ClientDataset({'x': jnp.array([8., 10.])}),
           jax.random.PRNGKey(1)),
      ]
      state, client_diagnostics = algorithm.apply(state, clients)
      npt.assert_allclose(state.params['w'], 2.08)
      npt.assert_allclose(client_diagnostics[b'cid0']['delta_l2_norm'], 12.)
      npt.assert_allclose(client_diagnostics[b'cid1']['delta_l2_norm'], 6.)
Example #3
  def test_model_train_clients_evaluation_fn(self):
    sampler = FakeClientSampler()
    sampler.set_round_num(1)
    clients = [sampler.sample()[0] for _ in range(4)]
    eval_fn = federated_experiment.ModelTrainClientsEvaluationFn(
        fake_model(), client_datasets.PaddedBatchHParams(batch_size=4))
    state = FakeState()
    npt.assert_equal(eval_fn(state, 1, clients), {'accuracy': np.array(0.25)})
    npt.assert_equal(eval_fn(state, 100, clients), {'accuracy': np.array(0.25)})
Example #4
  def test_model_full_evaluation_fn(self):
    sampler = FakeClientSampler()
    sampler.set_round_num(1)
    clients = [sampler.sample()[0] for _ in range(4)]
    fd = in_memory_federated_data.InMemoryFederatedData(
        dict((k, v.all_examples()) for k, v, _ in clients))
    eval_fn = federated_experiment.ModelFullEvaluationFn(
        fd, fake_model(), client_datasets.PaddedBatchHParams(batch_size=4))
    state = FakeState()
    npt.assert_equal(eval_fn(state, 1), {'accuracy': np.array(0.25)})
    npt.assert_equal(eval_fn(state, 100), {'accuracy': np.array(0.25)})
Example #5
  def test_model_sample_clients_evaluation_fn(self):
    eval_fn = federated_experiment.ModelSampleClientsEvaluationFn(
        FakeClientSampler(), fake_model(),
        client_datasets.PaddedBatchHParams(batch_size=4))
    state = FakeState()
    npt.assert_equal(eval_fn(state, 1), {'accuracy': np.array(1.)})
    npt.assert_equal(eval_fn(state, 2), {'accuracy': np.array(0.)})
    npt.assert_equal(eval_fn(state, 3), {'accuracy': np.array(0.)})
    npt.assert_equal(eval_fn(state, 4), {'accuracy': np.array(0.)})
    npt.assert_equal(eval_fn(state, 5), {'accuracy': np.array(0.)})
    npt.assert_equal(eval_fn(state, 6), {'accuracy': np.array(1.)})
Example #6
  def test_get(self):
    with self.subTest('default'):
      self.assertEqual(
          self.SHUFFLE_BATCH_REPEAT.get(),
          client_datasets.ShuffleRepeatBatchHParams(batch_size=128))
      self.assertEqual(
          self.PADDED_BATCH.get(),
          client_datasets.PaddedBatchHParams(batch_size=128))
      self.assertEqual(self.BATCH.get(),
                       client_datasets.BatchHParams(batch_size=128))
    with self.subTest('custom'):
      with flagsaver.flagsaver(shuffle_batch_repeat_batch_size=12,
                               shuffle_batch_repeat_num_epochs=21,
                               padded_batch_size=34,
                               batch_size=56):
        self.assertEqual(
            self.SHUFFLE_BATCH_REPEAT.get(),
            client_datasets.ShuffleRepeatBatchHParams(batch_size=12,
                                                      num_epochs=21))
        self.assertEqual(
            self.PADDED_BATCH.get(),
            client_datasets.PaddedBatchHParams(batch_size=34))
        self.assertEqual(self.BATCH.get(),
                         client_datasets.BatchHParams(batch_size=56))
Example #7
  def test_padded_batch(self):
    d = client_datasets.ClientDataset(
        {
            'a': np.arange(5),
            'b': np.arange(10).reshape([5, 2])
        }, client_datasets.BatchPreprocessor([lambda x: {
            **x, 'a': 2 * x['a']
        }]))
    with self.subTest('1 bucket, kwargs'):
      view = d.padded_batch(batch_size=3)
      # `view` should be repeatedly iterable.
      for _ in range(2):
        batches = list(view)
        self.assertLen(batches, 2)
        npt.assert_equal(
            batches[0], {
                'a': [0, 2, 4],
                'b': [[0, 1], [2, 3], [4, 5]],
                '__mask__': [True, True, True],
            })
        npt.assert_equal(
            batches[1], {
                'a': [6, 8, 0],
                'b': [[6, 7], [8, 9], [0, 0]],
                '__mask__': [True, True, False]
            })
    with self.subTest('2 buckets, kwargs override'):
      view = d.padded_batch(
          client_datasets.PaddedBatchHParams(batch_size=4),
          num_batch_size_buckets=2)
      # `view` should be repeatedly iterable.
      for _ in range(2):
        batches = list(view)
        self.assertLen(batches, 2)
        npt.assert_equal(
            batches[0], {
                'a': [0, 2, 4, 6],
                'b': [[0, 1], [2, 3], [4, 5], [6, 7]],
                '__mask__': [True, True, True, True],
            })
        npt.assert_equal(batches[1], {
            'a': [8, 0],
            'b': [[8, 9], [0, 0]],
            '__mask__': [True, False]
        })
Example #8
    def test_agnostic_federated_averaging(self):
        algorithm = agnostic_fed_avg.agnostic_federated_averaging(
            per_example_loss=per_example_loss,
            client_optimizer=optimizers.sgd(learning_rate=1.0),
            server_optimizer=optimizers.sgd(learning_rate=0.1),
            client_batch_hparams=client_datasets.ShuffleRepeatBatchHParams(
                batch_size=3, num_epochs=1, seed=0),
            domain_batch_hparams=client_datasets.PaddedBatchHParams(
                batch_size=3),
            init_domain_weights=[0.1, 0.2, 0.3, 0.4],
            domain_learning_rate=0.01,
            domain_algorithm='eg',
            domain_window_size=2,
            init_domain_window=[1., 2., 3., 4.])

        with self.subTest('init'):
            state = algorithm.init({'w': jnp.array(4.)})
            npt.assert_equal(state.params, {'w': jnp.array(4.)})
            self.assertLen(state.opt_state, 2)
            npt.assert_allclose(state.domain_weights, [0.1, 0.2, 0.3, 0.4])
            npt.assert_allclose(state.domain_window,
                                [[1., 2., 3., 4.], [1., 2., 3., 4.]])

        with self.subTest('apply'):
            clients = [
                (b'cid0',
                 client_datasets.ClientDataset({
                     'x':
                     jnp.array([1., 2., 4., 3., 6., 1.]),
                     'domain_id':
                     jnp.array([1, 0, 0, 0, 2, 2])
                 }), jax.random.PRNGKey(0)),
                (b'cid1',
                 client_datasets.ClientDataset({
                     'x':
                     jnp.array([8., 10., 5.]),
                     'domain_id':
                     jnp.array([1, 3, 1])
                 }), jax.random.PRNGKey(1)),
            ]
            next_state, client_diagnostics = algorithm.apply(state, clients)
            npt.assert_allclose(next_state.params['w'], 3.5555556)
            npt.assert_allclose(
                next_state.domain_weights,
                [0.08702461, 0.18604803, 0.2663479, 0.46057943])
            npt.assert_allclose(next_state.domain_window,
                                [[1., 2., 3., 4.], [3., 3., 2., 1.]])
            npt.assert_allclose(client_diagnostics[b'cid0']['delta_l2_norm'],
                                2.8333335)
            npt.assert_allclose(client_diagnostics[b'cid1']['delta_l2_norm'],
                                7.666667)

        with self.subTest('invalid init_domain_weights'):
            with self.assertRaisesRegex(
                    ValueError,
                    'init_domain_weights must sum to approximately 1.'):
                agnostic_fed_avg.agnostic_federated_averaging(
                    per_example_loss=per_example_loss,
                    client_optimizer=optimizers.sgd(learning_rate=1.0),
                    server_optimizer=optimizers.sgd(learning_rate=1.0),
                    client_batch_hparams=client_datasets.
                    ShuffleRepeatBatchHParams(batch_size=3),
                    domain_batch_hparams=client_datasets.PaddedBatchHParams(
                        batch_size=3),
                    init_domain_weights=[50., 0., 0., 0.],
                    domain_learning_rate=0.5)

        with self.subTest('unequal lengths'):
            with self.assertRaisesRegex(
                    ValueError,
                    'init_domain_weights and init_domain_window must be equal lengths.'
            ):
                agnostic_fed_avg.agnostic_federated_averaging(
                    per_example_loss=per_example_loss,
                    client_optimizer=optimizers.sgd(learning_rate=1.0),
                    server_optimizer=optimizers.sgd(learning_rate=1.0),
                    client_batch_hparams=client_datasets.
                    ShuffleRepeatBatchHParams(batch_size=3),
                    domain_batch_hparams=client_datasets.PaddedBatchHParams(
                        batch_size=3),
                    init_domain_weights=[0.1, 0.2, 0.3, 0.4],
                    domain_learning_rate=0.5,
                    init_domain_window=[1, 2])
Example #9
    def test_hyp_cluster(self):
        functions_called = set()

        def per_example_loss(params, batch, rng):
            self.assertIsNotNone(rng)
            functions_called.add('per_example_loss')
            return jnp.square(params - batch['x'])

        def regularizer(params):
            del params
            functions_called.add('regularizer')
            return 0

        client_optimizer = optimizers.sgd(0.5)
        server_optimizer = optimizers.sgd(0.25)

        maximization_batch_hparams = client_datasets.PaddedBatchHParams(
            batch_size=2)
        expectation_batch_hparams = client_datasets.ShuffleRepeatBatchHParams(
            batch_size=1, num_epochs=5)

        algorithm = hyp_cluster.hyp_cluster(
            per_example_loss=per_example_loss,
            client_optimizer=client_optimizer,
            server_optimizer=server_optimizer,
            maximization_batch_hparams=maximization_batch_hparams,
            expectation_batch_hparams=expectation_batch_hparams,
            regularizer=regularizer)

        init_state = algorithm.init([jnp.array(1.), jnp.array(-1.)])
        # Nothing happens with empty data.
        no_op_state, diagnostics = algorithm.apply(init_state, clients=[])
        npt.assert_array_equal(init_state.cluster_params,
                               no_op_state.cluster_params)
        self.assertEmpty(diagnostics)
        # Some actual training. PRNGKeys are not actually used.
        clients = [
            (b'0', client_datasets.ClientDataset({'x': np.array([1.1])}),
             jax.random.PRNGKey(0)),
            (b'1', client_datasets.ClientDataset({'x': np.array([0.9, 0.9])}),
             jax.random.PRNGKey(1)),
            (b'2', client_datasets.ClientDataset({'x': np.array([-1.1])}),
             jax.random.PRNGKey(2)),
            (b'3',
             client_datasets.ClientDataset({'x': np.array([-0.9, -0.9,
                                                           -0.9])}),
             jax.random.PRNGKey(3)),
        ]
        next_state, diagnostics = algorithm.apply(init_state, clients)
        npt.assert_equal(
            diagnostics, {
                b'0': {
                    'cluster_id': 0
                },
                b'1': {
                    'cluster_id': 0
                },
                b'2': {
                    'cluster_id': 1
                },
                b'3': {
                    'cluster_id': 1
                },
            })
        cluster_params = next_state.cluster_params
        self.assertIsInstance(cluster_params, list)
        self.assertLen(cluster_params, 2)
        npt.assert_allclose(cluster_params[0], [1. - 0.25 * 0.1 / 3])
        npt.assert_allclose(cluster_params[1], [-1. + 0.25 * 0.2 / 4])
        self.assertCountEqual(functions_called,
                              ['per_example_loss', 'regularizer'])
Example #10
    def test_hyp_cluster_evaluator(self):
        functions_called = set()

        def apply_for_eval(params, batch):
            functions_called.add('apply_for_eval')
            score = params * batch['x']
            return jnp.stack([-score, score], axis=-1)

        def apply_for_train(params, batch, rng):
            functions_called.add('apply_for_train')
            self.assertIsNotNone(rng)
            return params * batch['x']

        def train_loss(batch, out):
            functions_called.add('train_loss')
            return jnp.abs(batch['y'] * 2 - 1 - out)

        def regularizer(params):
            # Just to check regularizer is called.
            del params
            functions_called.add('regularizer')
            return 0

        evaluator = hyp_cluster.HypClusterEvaluator(
            models.Model(init=None,
                         apply_for_eval=apply_for_eval,
                         apply_for_train=apply_for_train,
                         train_loss=train_loss,
                         eval_metrics={'accuracy': metrics.Accuracy()}),
            regularizer)

        cluster_params = [jnp.array(1.), jnp.array(-1.)]
        train_clients = [
            # Evaluated using cluster 0.
            (b'0',
             client_datasets.ClientDataset({
                 'x': np.array([3., 2., 1.]),
                 'y': np.array([1, 1, 0])
             }), jax.random.PRNGKey(0)),
            # Evaluated using cluster 1.
            (b'1',
             client_datasets.ClientDataset({
                 'x':
                 np.array([0.9, -0.9, 0.8, -0.8, -0.3]),
                 'y':
                 np.array([0, 1, 0, 1, 0])
             }), jax.random.PRNGKey(1)),
        ]
        # Test clients are generated from train_clients by swapping client ids and
        # then flipping labels.
        test_clients = [
            # Evaluated using cluster 0.
            (b'0',
             client_datasets.ClientDataset({
                 'x':
                 np.array([0.9, -0.9, 0.8, -0.8, -0.3]),
                 'y':
                 np.array([1, 0, 1, 0, 1])
             })),
            # Evaluated using cluster 1.
            (b'1',
             client_datasets.ClientDataset({
                 'x': np.array([3., 2., 1.]),
                 'y': np.array([0, 0, 1])
             })),
        ]
        for batch_size in [1, 2, 4]:
            with self.subTest(f'batch_size = {batch_size}'):
                batch_hparams = client_datasets.PaddedBatchHParams(
                    batch_size=batch_size)
                metric_values = dict(
                    evaluator.evaluate_clients(cluster_params=cluster_params,
                                               train_clients=train_clients,
                                               test_clients=test_clients,
                                               batch_hparams=batch_hparams))
                self.assertCountEqual(metric_values, [b'0', b'1'])
                self.assertCountEqual(metric_values[b'0'], ['accuracy'])
                npt.assert_allclose(metric_values[b'0']['accuracy'], 4 / 5)
                self.assertCountEqual(metric_values[b'1'], ['accuracy'])
                npt.assert_allclose(metric_values[b'1']['accuracy'], 2 / 3)
        self.assertCountEqual(
            functions_called,
            ['apply_for_train', 'train_loss', 'apply_for_eval', 'regularizer'])
Example #11
    def test_maximization_step(self):
        # L1 distance from centers.
        def per_example_loss(params, batch, rng):
            # Randomly flip the center to test rng behavior.
            sign = jax.random.bernoulli(rng) * 2 - 1
            return jnp.abs(params * sign - batch['x'])

        def regularizer(params):
            return 0.01 * jnp.sum(jnp.abs(params))

        evaluator = models.AverageLossEvaluator(per_example_loss, regularizer)
        cluster_params = [jnp.array(0.), jnp.array(-1.), jnp.array(2.)]
        # Batch size is chosen so that we run 1 or 2 batches.
        batch_hparams = client_datasets.PaddedBatchHParams(batch_size=2)
        # Special seeds:
        # - No flip in first 2 steps for all 3 clusters: 0;
        # - Flip all in first 2 steps for all 3 clusters: 16;
        # - No flip then flip all for all 3 clusters: 68;
        # - Flips only cluster 1 in first 2 steps: 106.
        clients = [
            # No flip in first 2 steps for all 3 clusters.
            (b'near0', client_datasets.ClientDataset({'x': np.array([0.1])}),
             jax.random.PRNGKey(0)),
            # Flip all in first 2 steps for all 3 clusters.
            (b'near-1',
             client_datasets.ClientDataset({'x': np.array([0.9, 1.1, 1.3])}),
             jax.random.PRNGKey(16)),
            # No flip then flip all for all 3 clusters.
            (b'near2',
             client_datasets.ClientDataset({'x': np.array([1.9, 2.1, -2.1])}),
             jax.random.PRNGKey(68)),
            # Flips only cluster 1 in first 2 steps.
            (b'near1',
             client_datasets.ClientDataset({'x': np.array([0.9, 1.1, 1.3])}),
             jax.random.PRNGKey(106)),
        ]

        cluster_losses = hyp_cluster._cluster_losses(
            evaluator=evaluator,
            cluster_params=cluster_params,
            clients=clients,
            batch_hparams=batch_hparams)
        self.assertCountEqual(cluster_losses,
                              [b'near0', b'near-1', b'near2', b'near1'])
        npt.assert_allclose(cluster_losses[b'near0'],
                            np.array([0.1, 1.1 + 0.01, 1.9 + 0.02]))
        npt.assert_allclose(cluster_losses[b'near-1'],
                            np.array([1.1, 0.5 / 3 + 0.01, 3.1 + 0.02]))
        npt.assert_allclose(cluster_losses[b'near2'],
                            np.array([6.1 / 3, 9.1 / 3 + 0.01, 0.1 + 0.02]),
                            rtol=1e-6)
        npt.assert_allclose(cluster_losses[b'near1'],
                            np.array([1.1, 0.5 / 3 + 0.01, 0.9 + 0.02]))

        client_cluster_ids = hyp_cluster.maximization_step(
            evaluator=evaluator,
            cluster_params=cluster_params,
            clients=clients,
            batch_hparams=batch_hparams)
        self.assertDictEqual(client_cluster_ids, {
            b'near0': 0,
            b'near-1': 1,
            b'near2': 2,
            b'near1': 1
        })
Example #12
    def test_kmeans_init(self):
        functions_called = set()

        def init(rng):
            functions_called.add('init')
            return jax.random.uniform(rng)

        def apply_for_train(params, batch, rng):
            functions_called.add('apply_for_train')
            self.assertIsNotNone(rng)
            return params - batch['x']

        def train_loss(batch, out):
            functions_called.add('train_loss')
            return jnp.square(out) + batch['bias']

        def regularizer(params):
            del params
            functions_called.add('regularizer')
            return 0

        initializer = hyp_cluster.ModelKMeansInitializer(
            models.Model(init=init,
                         apply_for_train=apply_for_train,
                         apply_for_eval=None,
                         train_loss=train_loss,
                         eval_metrics={}), optimizers.sgd(0.5), regularizer)
        # Each client has 1 example, so it's very easy to reach minimal loss, at
        # which point the loss entirely depends on bias.
        clients = [
            (b'0',
             client_datasets.ClientDataset({
                 'x': np.array([1.01]),
                 'bias': np.array([-2.])
             }), jax.random.PRNGKey(1)),
            (b'1',
             client_datasets.ClientDataset({
                 'x': np.array([3.02]),
                 'bias': np.array([-1.])
             }), jax.random.PRNGKey(2)),
            (b'2',
             client_datasets.ClientDataset({
                 'x': np.array([3.03]),
                 'bias': np.array([1.])
             }), jax.random.PRNGKey(3)),
            (b'3',
             client_datasets.ClientDataset({
                 'x': np.array([1.04]),
                 'bias': np.array([2.])
             }), jax.random.PRNGKey(3)),
        ]
        train_batch_hparams = client_datasets.ShuffleRepeatBatchHParams(
            batch_size=1, num_epochs=5)
        eval_batch_hparams = client_datasets.PaddedBatchHParams(batch_size=2)
        # Using a rng that leads to b'0' being the initial center.
        cluster_params = initializer.cluster_params(
            num_clusters=3,
            rng=jax.random.PRNGKey(0),
            clients=clients,
            train_batch_hparams=train_batch_hparams,
            eval_batch_hparams=eval_batch_hparams)
        self.assertIsInstance(cluster_params, list)
        self.assertLen(cluster_params, 3)
        npt.assert_allclose(cluster_params, [1.01, 3.03, 1.04])
        self.assertCountEqual(
            functions_called,
            ['init', 'apply_for_train', 'train_loss', 'regularizer'])
        # Using a rng that leads to b'2' being the initial center.
        cluster_params = initializer.cluster_params(
            num_clusters=3,
            rng=jax.random.PRNGKey(1),
            clients=clients,
            train_batch_hparams=train_batch_hparams,
            eval_batch_hparams=eval_batch_hparams)
        self.assertIsInstance(cluster_params, list)
        self.assertLen(cluster_params, 3)
        npt.assert_allclose(cluster_params, [3.03, 1.04, 1.04])
Example #13
  def get(self):
    return client_datasets.PaddedBatchHParams(
        batch_size=self._get_flag('batch_size'))
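
Taken together, the examples follow one pattern: build a client_datasets.PaddedBatchHParams (directly, or from flags as in Example #13) and hand it to a batching or evaluation API, which pads the final batch up to batch_size and marks padded rows with a '__mask__' field (see Example #7). Below is a minimal stand-alone sketch of that pattern; the data and the print call are illustrative only, and the import assumes the fedjax.core.client_datasets module used by the test files above.

import numpy as np
from fedjax.core import client_datasets

# Illustrative client data (not taken from the tests above).
dataset = client_datasets.ClientDataset({'x': np.arange(5)})

# Iterate padded batches: the last batch is padded to batch_size and the
# padded rows are marked False under '__mask__'.
hparams = client_datasets.PaddedBatchHParams(batch_size=3)
for batch in dataset.padded_batch(hparams):
  print(batch['x'], batch['__mask__'])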