def test_client_delta_clip_norm(self):
    """Mime-lite with delta clipping: client deltas are clipped to norm 0.5."""
    base_opt = optimizers.sgd(learning_rate=1.0)
    shuffle_hparams = client_datasets.ShuffleRepeatBatchHParams(
        batch_size=2, num_epochs=1, seed=0)
    pad_hparams = client_datasets.PaddedBatchHParams(batch_size=2)
    server_lr = 0.2
    algorithm = mime_lite.mime_lite(
        per_example_loss,
        base_opt,
        shuffle_hparams,
        pad_hparams,
        server_lr,
        client_delta_clip_norm=0.5)
    client_data = [
        (b'cid0',
         client_datasets.ClientDataset({'x': jnp.array([0.2, 0.4, 0.6])}),
         jax.random.PRNGKey(0)),
        (b'cid1',
         client_datasets.ClientDataset({'x': jnp.array([0.8, 0.1])}),
         jax.random.PRNGKey(1)),
    ]
    state = algorithm.init({'w': jnp.array(4.)})
    state, diagnostics = algorithm.apply(state, client_data)
    npt.assert_allclose(state.params['w'], 3.904)
    # cid0's raw delta exceeds the clip norm so it is reported exactly at 0.5;
    # cid1's delta was already below the threshold.
    npt.assert_allclose(diagnostics[b'cid0']['clipped_delta_l2_norm'], 0.5)
    npt.assert_allclose(
        diagnostics[b'cid1']['clipped_delta_l2_norm'], 0.45000005)
def test_mime(self):
    """End-to-end Mime: server init then one round over two clients."""
    base_opt = optimizers.sgd(learning_rate=1.0)
    shuffle_hparams = client_datasets.ShuffleRepeatBatchHParams(
        batch_size=2, num_epochs=1, seed=0)
    pad_hparams = client_datasets.PaddedBatchHParams(batch_size=2)
    server_lr = 0.2
    algorithm = mime.mime(per_example_loss, base_opt, shuffle_hparams,
                          pad_hparams, server_lr)
    with self.subTest('init'):
        state = algorithm.init({'w': jnp.array(4.)})
        npt.assert_equal(state.params, {'w': jnp.array(4.)})
        self.assertLen(state.opt_state, 2)
    with self.subTest('apply'):
        client_data = [
            (b'cid0',
             client_datasets.ClientDataset({'x': jnp.array([2., 4., 6.])}),
             jax.random.PRNGKey(0)),
            (b'cid1',
             client_datasets.ClientDataset({'x': jnp.array([8., 10.])}),
             jax.random.PRNGKey(1)),
        ]
        state, diagnostics = algorithm.apply(state, client_data)
        npt.assert_allclose(state.params['w'], 2.08)
        npt.assert_allclose(diagnostics[b'cid0']['delta_l2_norm'], 12.)
        npt.assert_allclose(diagnostics[b'cid1']['delta_l2_norm'], 6.)
def test_model_train_clients_evaluation_fn(self):
    """Evaluation over an explicit client list is independent of round number."""
    sampler = FakeClientSampler()
    sampler.set_round_num(1)
    fixed_clients = [sampler.sample()[0] for _ in range(4)]
    eval_fn = federated_experiment.ModelTrainClientsEvaluationFn(
        fake_model(), client_datasets.PaddedBatchHParams(batch_size=4))
    state = FakeState()
    # Same clients are passed in explicitly, so both rounds see identical data.
    npt.assert_equal(
        eval_fn(state, 1, fixed_clients), {'accuracy': np.array(0.25)})
    npt.assert_equal(
        eval_fn(state, 100, fixed_clients), {'accuracy': np.array(0.25)})
def test_model_full_evaluation_fn(self):
    """Full-dataset evaluation yields the same metric regardless of round."""
    sampler = FakeClientSampler()
    sampler.set_round_num(1)
    sampled = [sampler.sample()[0] for _ in range(4)]
    # Rebuild a federated dataset from the sampled clients' examples.
    federated_data = in_memory_federated_data.InMemoryFederatedData(
        {client_id: dataset.all_examples() for client_id, dataset, _ in sampled})
    eval_fn = federated_experiment.ModelFullEvaluationFn(
        federated_data, fake_model(),
        client_datasets.PaddedBatchHParams(batch_size=4))
    state = FakeState()
    npt.assert_equal(eval_fn(state, 1), {'accuracy': np.array(0.25)})
    npt.assert_equal(eval_fn(state, 100), {'accuracy': np.array(0.25)})
def test_model_sample_clients_evaluation_fn(self):
    """Sampled-client evaluation varies with the round number it samples for."""
    eval_fn = federated_experiment.ModelSampleClientsEvaluationFn(
        FakeClientSampler(), fake_model(),
        client_datasets.PaddedBatchHParams(batch_size=4))
    state = FakeState()
    # Expected accuracy per round, as produced by the fake sampler's rotation.
    expected = [1., 0., 0., 0., 0., 1.]
    for round_num, accuracy in enumerate(expected, start=1):
        npt.assert_equal(
            eval_fn(state, round_num), {'accuracy': np.array(accuracy)})
def test_get(self):
    """Hyperparam accessors honor flag defaults and flag overrides."""
    with self.subTest('default'):
        self.assertEqual(
            self.SHUFFLE_BATCH_REPEAT.get(),
            client_datasets.ShuffleRepeatBatchHParams(batch_size=128))
        self.assertEqual(
            self.PADDED_BATCH.get(),
            client_datasets.PaddedBatchHParams(batch_size=128))
        self.assertEqual(
            self.BATCH.get(), client_datasets.BatchHParams(batch_size=128))
    with self.subTest('custom'):
        overrides = dict(
            shuffle_batch_repeat_batch_size=12,
            shuffle_batch_repeat_num_epochs=21,
            padded_batch_size=34,
            batch_size=56)
        with flagsaver.flagsaver(**overrides):
            self.assertEqual(
                self.SHUFFLE_BATCH_REPEAT.get(),
                client_datasets.ShuffleRepeatBatchHParams(
                    batch_size=12, num_epochs=21))
            self.assertEqual(
                self.PADDED_BATCH.get(),
                client_datasets.PaddedBatchHParams(batch_size=34))
            self.assertEqual(
                self.BATCH.get(), client_datasets.BatchHParams(batch_size=56))
def test_padded_batch(self):
    """padded_batch pads the last partial batch and flags padding via __mask__."""
    # Preprocessor doubles 'a' so batched output shows preprocessing ran.
    dataset = client_datasets.ClientDataset(
        {
            'a': np.arange(5),
            'b': np.arange(10).reshape([5, 2])
        },
        client_datasets.BatchPreprocessor([lambda x: {**x, 'a': 2 * x['a']}]))
    with self.subTest('1 bucket, kwargs'):
        batched = dataset.padded_batch(batch_size=3)
        # The view supports repeated iteration.
        for _ in range(2):
            got = list(batched)
            self.assertLen(got, 2)
            npt.assert_equal(
                got[0], {
                    'a': [0, 2, 4],
                    'b': [[0, 1], [2, 3], [4, 5]],
                    '__mask__': [True, True, True],
                })
            npt.assert_equal(
                got[1], {
                    'a': [6, 8, 0],
                    'b': [[6, 7], [8, 9], [0, 0]],
                    '__mask__': [True, True, False]
                })
    with self.subTest('2 buckets, kwargs override'):
        batched = dataset.padded_batch(
            client_datasets.PaddedBatchHParams(batch_size=4),
            num_batch_size_buckets=2)
        for _ in range(2):
            got = list(batched)
            self.assertLen(got, 2)
            npt.assert_equal(
                got[0], {
                    'a': [0, 2, 4, 6],
                    'b': [[0, 1], [2, 3], [4, 5], [6, 7]],
                    '__mask__': [True, True, True, True],
                })
            npt.assert_equal(got[1], {
                'a': [8, 0],
                'b': [[8, 9], [0, 0]],
                '__mask__': [True, False]
            })
def test_agnostic_federated_averaging(self):
    """AgnosticFedAvg: init, one round of apply, and constructor validation."""
    algorithm = agnostic_fed_avg.agnostic_federated_averaging(
        per_example_loss=per_example_loss,
        client_optimizer=optimizers.sgd(learning_rate=1.0),
        server_optimizer=optimizers.sgd(learning_rate=0.1),
        client_batch_hparams=client_datasets.ShuffleRepeatBatchHParams(
            batch_size=3, num_epochs=1, seed=0),
        domain_batch_hparams=client_datasets.PaddedBatchHParams(batch_size=3),
        init_domain_weights=[0.1, 0.2, 0.3, 0.4],
        domain_learning_rate=0.01,
        domain_algorithm='eg',
        domain_window_size=2,
        init_domain_window=[1., 2., 3., 4.])
    with self.subTest('init'):
        state = algorithm.init({'w': jnp.array(4.)})
        npt.assert_equal(state.params, {'w': jnp.array(4.)})
        self.assertLen(state.opt_state, 2)
        npt.assert_allclose(state.domain_weights, [0.1, 0.2, 0.3, 0.4])
        # Initial window repeats init_domain_window for each window slot.
        npt.assert_allclose(state.domain_window,
                            [[1., 2., 3., 4.], [1., 2., 3., 4.]])
    with self.subTest('apply'):
        client_data = [
            (b'cid0',
             client_datasets.ClientDataset({
                 'x': jnp.array([1., 2., 4., 3., 6., 1.]),
                 'domain_id': jnp.array([1, 0, 0, 0, 2, 2])
             }), jax.random.PRNGKey(0)),
            (b'cid1',
             client_datasets.ClientDataset({
                 'x': jnp.array([8., 10., 5.]),
                 'domain_id': jnp.array([1, 3, 1])
             }), jax.random.PRNGKey(1)),
        ]
        next_state, diagnostics = algorithm.apply(state, client_data)
        npt.assert_allclose(next_state.params['w'], 3.5555556)
        npt.assert_allclose(
            next_state.domain_weights,
            [0.08702461, 0.18604803, 0.2663479, 0.46057943])
        # New window row holds the per-domain example counts from this round.
        npt.assert_allclose(next_state.domain_window,
                            [[1., 2., 3., 4.], [3., 3., 2., 1.]])
        npt.assert_allclose(diagnostics[b'cid0']['delta_l2_norm'], 2.8333335)
        npt.assert_allclose(diagnostics[b'cid1']['delta_l2_norm'], 7.666667)
    with self.subTest('invalid init_domain_weights'):
        with self.assertRaisesRegex(
            ValueError, 'init_domain_weights must sum to approximately 1.'):
            agnostic_fed_avg.agnostic_federated_averaging(
                per_example_loss=per_example_loss,
                client_optimizer=optimizers.sgd(learning_rate=1.0),
                server_optimizer=optimizers.sgd(learning_rate=1.0),
                client_batch_hparams=client_datasets.ShuffleRepeatBatchHParams(
                    batch_size=3),
                domain_batch_hparams=client_datasets.PaddedBatchHParams(
                    batch_size=3),
                init_domain_weights=[50., 0., 0., 0.],
                domain_learning_rate=0.5)
    with self.subTest('unequal lengths'):
        with self.assertRaisesRegex(
            ValueError,
            'init_domain_weights and init_domain_window must be equal lengths.'
        ):
            agnostic_fed_avg.agnostic_federated_averaging(
                per_example_loss=per_example_loss,
                client_optimizer=optimizers.sgd(learning_rate=1.0),
                server_optimizer=optimizers.sgd(learning_rate=1.0),
                client_batch_hparams=client_datasets.ShuffleRepeatBatchHParams(
                    batch_size=3),
                domain_batch_hparams=client_datasets.PaddedBatchHParams(
                    batch_size=3),
                init_domain_weights=[0.1, 0.2, 0.3, 0.4],
                domain_learning_rate=0.5,
                init_domain_window=[1, 2])
def test_hyp_cluster(self):
    """HypCluster: empty apply is a no-op; real data assigns clusters and trains."""
    functions_called = set()

    def per_example_loss(params, batch, rng):
        # rng must be threaded through to the loss.
        self.assertIsNotNone(rng)
        functions_called.add('per_example_loss')
        return jnp.square(params - batch['x'])

    def regularizer(params):
        del params
        functions_called.add('regularizer')
        return 0

    algorithm = hyp_cluster.hyp_cluster(
        per_example_loss=per_example_loss,
        client_optimizer=optimizers.sgd(0.5),
        server_optimizer=optimizers.sgd(0.25),
        maximization_batch_hparams=client_datasets.PaddedBatchHParams(
            batch_size=2),
        expectation_batch_hparams=client_datasets.ShuffleRepeatBatchHParams(
            batch_size=1, num_epochs=5),
        regularizer=regularizer)
    init_state = algorithm.init([jnp.array(1.), jnp.array(-1.)])
    # Nothing happens with empty data.
    no_op_state, diagnostics = algorithm.apply(init_state, clients=[])
    npt.assert_array_equal(init_state.cluster_params,
                           no_op_state.cluster_params)
    self.assertEmpty(diagnostics)
    # Some actual training. PRNGKeys are not actually used.
    client_data = [
        (b'0', client_datasets.ClientDataset({'x': np.array([1.1])}),
         jax.random.PRNGKey(0)),
        (b'1', client_datasets.ClientDataset({'x': np.array([0.9, 0.9])}),
         jax.random.PRNGKey(1)),
        (b'2', client_datasets.ClientDataset({'x': np.array([-1.1])}),
         jax.random.PRNGKey(2)),
        (b'3',
         client_datasets.ClientDataset({'x': np.array([-0.9, -0.9, -0.9])}),
         jax.random.PRNGKey(3)),
    ]
    next_state, diagnostics = algorithm.apply(init_state, client_data)
    # Clients near +1 go to cluster 0; clients near -1 go to cluster 1.
    npt.assert_equal(
        diagnostics, {
            b'0': {'cluster_id': 0},
            b'1': {'cluster_id': 0},
            b'2': {'cluster_id': 1},
            b'3': {'cluster_id': 1},
        })
    updated_centers = next_state.cluster_params
    self.assertIsInstance(updated_centers, list)
    self.assertLen(updated_centers, 2)
    npt.assert_allclose(updated_centers[0], [1. - 0.25 * 0.1 / 3])
    npt.assert_allclose(updated_centers[1], [-1. + 0.25 * 0.2 / 4])
    self.assertCountEqual(functions_called,
                          ['per_example_loss', 'regularizer'])
def test_hyp_cluster_evaluator(self):
    """Evaluator picks each client's best cluster on train data, scores on test data."""
    functions_called = set()

    def apply_for_eval(params, batch):
        functions_called.add('apply_for_eval')
        score = params * batch['x']
        # Two-class logits: [-score, score].
        return jnp.stack([-score, score], axis=-1)

    def apply_for_train(params, batch, rng):
        functions_called.add('apply_for_train')
        self.assertIsNotNone(rng)
        return params * batch['x']

    def train_loss(batch, out):
        functions_called.add('train_loss')
        return jnp.abs(batch['y'] * 2 - 1 - out)

    def regularizer(params):
        # Just to check regularizer is called.
        del params
        functions_called.add('regularizer')
        return 0

    evaluator = hyp_cluster.HypClusterEvaluator(
        models.Model(
            init=None,
            apply_for_eval=apply_for_eval,
            apply_for_train=apply_for_train,
            train_loss=train_loss,
            eval_metrics={'accuracy': metrics.Accuracy()}), regularizer)
    cluster_params = [jnp.array(1.), jnp.array(-1.)]
    train_clients = [
        # Evaluated using cluster 0.
        (b'0',
         client_datasets.ClientDataset({
             'x': np.array([3., 2., 1.]),
             'y': np.array([1, 1, 0])
         }), jax.random.PRNGKey(0)),
        # Evaluated using cluster 1.
        (b'1',
         client_datasets.ClientDataset({
             'x': np.array([0.9, -0.9, 0.8, -0.8, -0.3]),
             'y': np.array([0, 1, 0, 1, 0])
         }), jax.random.PRNGKey(1)),
    ]
    # Test clients are generated from train_clients by swapping client ids and
    # then flipping labels.
    test_clients = [
        # Evaluated using cluster 0.
        (b'0',
         client_datasets.ClientDataset({
             'x': np.array([0.9, -0.9, 0.8, -0.8, -0.3]),
             'y': np.array([1, 0, 1, 0, 1])
         })),
        # Evaluated using cluster 1.
        (b'1',
         client_datasets.ClientDataset({
             'x': np.array([3., 2., 1.]),
             'y': np.array([0, 0, 1])
         })),
    ]
    for batch_size in [1, 2, 4]:
        with self.subTest(f'batch_size = {batch_size}'):
            batch_hparams = client_datasets.PaddedBatchHParams(
                batch_size=batch_size)
            metric_values = dict(
                evaluator.evaluate_clients(
                    cluster_params=cluster_params,
                    train_clients=train_clients,
                    test_clients=test_clients,
                    batch_hparams=batch_hparams))
            self.assertCountEqual(metric_values, [b'0', b'1'])
            self.assertCountEqual(metric_values[b'0'], ['accuracy'])
            npt.assert_allclose(metric_values[b'0']['accuracy'], 4 / 5)
            self.assertCountEqual(metric_values[b'1'], ['accuracy'])
            npt.assert_allclose(metric_values[b'1']['accuracy'], 2 / 3)
    self.assertCountEqual(
        functions_called,
        ['apply_for_train', 'train_loss', 'apply_for_eval', 'regularizer'])
def test_maximization_step(self):
    """Per-cluster losses and the resulting argmin cluster assignments."""

    # L1 distance from centers.
    def per_example_loss(params, batch, rng):
        # Randomly flip the center to test rng behavior.
        sign = jax.random.bernoulli(rng) * 2 - 1
        return jnp.abs(params * sign - batch['x'])

    def regularizer(params):
        return 0.01 * jnp.sum(jnp.abs(params))

    evaluator = models.AverageLossEvaluator(per_example_loss, regularizer)
    cluster_params = [jnp.array(0.), jnp.array(-1.), jnp.array(2.)]
    # Batch size is chosen so that we run 1 or 2 batches.
    batch_hparams = client_datasets.PaddedBatchHParams(batch_size=2)
    # Special seeds:
    # - No flip in first 2 steps for all 3 clusters: 0;
    # - Flip all in first 2 steps for all 3 clusters: 16;
    # - No flip then flip all for all 3 clusters: 68;
    # - Flips only cluster 1 in first 2 steps: 106.
    client_data = [
        # No flip in first 2 steps for all 3 clusters.
        (b'near0', client_datasets.ClientDataset({'x': np.array([0.1])}),
         jax.random.PRNGKey(0)),
        # Flip all in first 2 steps for all 3 clusters.
        (b'near-1',
         client_datasets.ClientDataset({'x': np.array([0.9, 1.1, 1.3])}),
         jax.random.PRNGKey(16)),
        # No flip then flip all for all 3 clusters.
        (b'near2',
         client_datasets.ClientDataset({'x': np.array([1.9, 2.1, -2.1])}),
         jax.random.PRNGKey(68)),
        # Flips only cluster 1 in first 2 steps.
        (b'near1',
         client_datasets.ClientDataset({'x': np.array([0.9, 1.1, 1.3])}),
         jax.random.PRNGKey(106)),
    ]
    cluster_losses = hyp_cluster._cluster_losses(
        evaluator=evaluator,
        cluster_params=cluster_params,
        clients=client_data,
        batch_hparams=batch_hparams)
    self.assertCountEqual(cluster_losses,
                          [b'near0', b'near-1', b'near2', b'near1'])
    npt.assert_allclose(cluster_losses[b'near0'],
                        np.array([0.1, 1.1 + 0.01, 1.9 + 0.02]))
    npt.assert_allclose(cluster_losses[b'near-1'],
                        np.array([1.1, 0.5 / 3 + 0.01, 3.1 + 0.02]))
    npt.assert_allclose(
        cluster_losses[b'near2'],
        np.array([6.1 / 3, 9.1 / 3 + 0.01, 0.1 + 0.02]),
        rtol=1e-6)
    npt.assert_allclose(cluster_losses[b'near1'],
                        np.array([1.1, 0.5 / 3 + 0.01, 0.9 + 0.02]))
    # The maximization step assigns each client its lowest-loss cluster.
    assignments = hyp_cluster.maximization_step(
        evaluator=evaluator,
        cluster_params=cluster_params,
        clients=client_data,
        batch_hparams=batch_hparams)
    self.assertDictEqual(assignments, {
        b'near0': 0,
        b'near-1': 1,
        b'near2': 2,
        b'near1': 1
    })
def test_kmeans_init(self):
    """kmeans++-style initialization picks centers from trained client params."""
    functions_called = set()

    def init(rng):
        functions_called.add('init')
        return jax.random.uniform(rng)

    def apply_for_train(params, batch, rng):
        functions_called.add('apply_for_train')
        self.assertIsNotNone(rng)
        return params - batch['x']

    def train_loss(batch, out):
        functions_called.add('train_loss')
        return jnp.square(out) + batch['bias']

    def regularizer(params):
        del params
        functions_called.add('regularizer')
        return 0

    initializer = hyp_cluster.ModelKMeansInitializer(
        models.Model(
            init=init,
            apply_for_train=apply_for_train,
            apply_for_eval=None,
            train_loss=train_loss,
            eval_metrics={}), optimizers.sgd(0.5), regularizer)
    # Each client has 1 example, so it's very easy to reach minimal loss, at
    # which point the loss entirely depends on bias.
    client_data = [
        (b'0',
         client_datasets.ClientDataset({
             'x': np.array([1.01]),
             'bias': np.array([-2.])
         }), jax.random.PRNGKey(1)),
        (b'1',
         client_datasets.ClientDataset({
             'x': np.array([3.02]),
             'bias': np.array([-1.])
         }), jax.random.PRNGKey(2)),
        (b'2',
         client_datasets.ClientDataset({
             'x': np.array([3.03]),
             'bias': np.array([1.])
         }), jax.random.PRNGKey(3)),
        (b'3',
         client_datasets.ClientDataset({
             'x': np.array([1.04]),
             'bias': np.array([2.])
         }), jax.random.PRNGKey(3)),
    ]
    train_hparams = client_datasets.ShuffleRepeatBatchHParams(
        batch_size=1, num_epochs=5)
    eval_hparams = client_datasets.PaddedBatchHParams(batch_size=2)
    # Using a rng that leads to b'0' being the initial center.
    centers = initializer.cluster_params(
        num_clusters=3,
        rng=jax.random.PRNGKey(0),
        clients=client_data,
        train_batch_hparams=train_hparams,
        eval_batch_hparams=eval_hparams)
    self.assertIsInstance(centers, list)
    self.assertLen(centers, 3)
    npt.assert_allclose(centers, [1.01, 3.03, 1.04])
    self.assertCountEqual(
        functions_called,
        ['init', 'apply_for_train', 'train_loss', 'regularizer'])
    # Using a rng that leads to b'2' being the initial center.
    centers = initializer.cluster_params(
        num_clusters=3,
        rng=jax.random.PRNGKey(1),
        clients=client_data,
        train_batch_hparams=train_hparams,
        eval_batch_hparams=eval_hparams)
    self.assertIsInstance(centers, list)
    self.assertLen(centers, 3)
    npt.assert_allclose(centers, [3.03, 1.04, 1.04])
def get(self):
    """Builds PaddedBatchHParams from the current 'batch_size' flag value."""
    batch_size = self._get_flag('batch_size')
    return client_datasets.PaddedBatchHParams(batch_size=batch_size)