def test_federated_averaging(self):
  """Checks fed_avg init/apply on a 3-dim linear model with SGD on both sides.

  Expected values are regression targets produced by this fixed setup
  (fixed shuffle seed and client PRNGKeys).
  """
  client_optimizer = optimizers.sgd(learning_rate=1.0)
  server_optimizer = optimizers.sgd(learning_rate=1.0)
  client_batch_hparams = client_datasets.ShuffleRepeatBatchHParams(
      batch_size=2, num_epochs=1, seed=0)
  algorithm = fed_avg.federated_averaging(grad_fn, client_optimizer,
                                          server_optimizer,
                                          client_batch_hparams)
  with self.subTest('init'):
    state = algorithm.init({'w': jnp.array([0., 2., 4.])})
    # Server params start at the supplied initial value.
    npt.assert_array_equal(state.params['w'], [0., 2., 4.])
    self.assertLen(state.opt_state, 2)
  with self.subTest('apply'):
    clients = [
        (b'cid0',
         client_datasets.ClientDataset({'x': jnp.array([2., 4., 6.])}),
         jax.random.PRNGKey(0)),
        (b'cid1',
         client_datasets.ClientDataset({'x': jnp.array([8., 10.])}),
         jax.random.PRNGKey(1)),
    ]
    state, client_diagnostics = algorithm.apply(state, clients)
    npt.assert_allclose(state.params['w'], [0., 1.5655555, 3.131111])
    # Per-client delta l2 norms are reported as diagnostics keyed by client id.
    npt.assert_allclose(client_diagnostics[b'cid0']['delta_l2_norm'],
                        1.4534444262)
    npt.assert_allclose(client_diagnostics[b'cid1']['delta_l2_norm'],
                        0.2484521282)
def test_fed_prox(self):
  """Checks FedProx init/apply on a scalar model with proximal_weight=0.01.

  Expected values are regression targets for this fixed setup.
  """
  client_optimizer = optimizers.sgd(learning_rate=1.0)
  server_optimizer = optimizers.sgd(learning_rate=1.0)
  client_batch_hparams = client_datasets.ShuffleRepeatBatchHParams(
      batch_size=2, num_epochs=1, seed=0)
  algorithm = fed_prox.fed_prox(per_example_loss, client_optimizer,
                                server_optimizer, client_batch_hparams,
                                proximal_weight=0.01)
  with self.subTest('init'):
    state = algorithm.init({'w': jnp.array(4.)})
    npt.assert_array_equal(state.params['w'], 4.)
    self.assertLen(state.opt_state, 2)
  with self.subTest('apply'):
    clients = [
        (b'cid0',
         client_datasets.ClientDataset({'x': jnp.array([2., 4., 6.])}),
         jax.random.PRNGKey(0)),
        (b'cid1',
         client_datasets.ClientDataset({'x': jnp.array([8., 10.])}),
         jax.random.PRNGKey(1)),
    ]
    state, client_diagnostics = algorithm.apply(state, clients)
    npt.assert_allclose(state.params['w'], -3.77)
    npt.assert_allclose(client_diagnostics[b'cid0']['delta_l2_norm'], 6.95)
    npt.assert_allclose(client_diagnostics[b'cid1']['delta_l2_norm'], 9.)
def test_multi(self):
  """Batching multiple datasets concatenates them and pads the final batch."""
  datasets = [
      client_datasets.ClientDataset({'x': np.arange(10)}),
      client_datasets.ClientDataset({'x': np.arange(10, 11)}),
      client_datasets.ClientDataset({'x': np.arange(11, 15)}),
      client_datasets.ClientDataset({'x': np.arange(15, 17)}),
  ]
  batches = list(
      client_datasets.padded_batch_client_datasets(datasets, batch_size=4))
  self.assertLen(batches, 5)
  expected = [
      {'x': [0, 1, 2, 3], '__mask__': [True, True, True, True]},
      {'x': [4, 5, 6, 7], '__mask__': [True, True, True, True]},
      {'x': [8, 9, 10, 11], '__mask__': [True, True, True, True]},
      {'x': [12, 13, 14, 15], '__mask__': [True, True, True, True]},
      # Last batch is padded with zeros; the mask marks padding entries.
      {'x': [16, 0, 0, 0], '__mask__': [True, False, False, False]},
  ]
  for got, want in zip(batches, expected):
    npt.assert_equal(got, want)
def test_multi(self):
  """Buffered shuffling over several datasets produces this fixed-seed order."""
  datasets = [
      client_datasets.ClientDataset({'x': np.arange(10)}),
      client_datasets.ClientDataset({'x': np.arange(10, 11)}),
      client_datasets.ClientDataset({'x': np.arange(11, 15)}),
      client_datasets.ClientDataset({'x': np.arange(15, 17)}),
  ]
  batches = list(
      client_datasets.buffered_shuffle_batch_client_datasets(
          datasets, batch_size=4, buffer_size=16,
          rng=np.random.RandomState(0)))
  self.assertLen(batches, 5)
  # Expected order is determined by RandomState(0) with a buffer of 16.
  expected = [
      {'x': [1, 6, 16, 8]},
      {'x': [9, 13, 4, 2]},
      {'x': [14, 10, 7, 15]},
      {'x': [11, 3, 0, 5]},
      {'x': [12]},
  ]
  for got, want in zip(batches, expected):
    npt.assert_equal(got, want)
def test_client_delta_clip_norm(self):
  """MimeLite with client_delta_clip_norm=0.5 clips large client deltas.

  Expected values are regression targets for this fixed setup.
  """
  base_optimizer = optimizers.sgd(learning_rate=1.0)
  train_batch_hparams = client_datasets.ShuffleRepeatBatchHParams(
      batch_size=2, num_epochs=1, seed=0)
  grad_batch_hparams = client_datasets.PaddedBatchHParams(batch_size=2)
  server_learning_rate = 0.2
  algorithm = mime_lite.mime_lite(per_example_loss, base_optimizer,
                                  train_batch_hparams, grad_batch_hparams,
                                  server_learning_rate,
                                  client_delta_clip_norm=0.5)
  clients = [
      (b'cid0',
       client_datasets.ClientDataset({'x': jnp.array([0.2, 0.4, 0.6])}),
       jax.random.PRNGKey(0)),
      (b'cid1',
       client_datasets.ClientDataset({'x': jnp.array([0.8, 0.1])}),
       jax.random.PRNGKey(1)),
  ]
  state = algorithm.init({'w': jnp.array(4.)})
  state, client_diagnostics = algorithm.apply(state, clients)
  npt.assert_allclose(state.params['w'], 3.904)
  # cid0's delta exceeds the clip norm, so it is clipped to exactly 0.5;
  # cid1's delta is already under the norm and passes through (up to float32).
  npt.assert_allclose(client_diagnostics[b'cid0']['clipped_delta_l2_norm'],
                      0.5)
  npt.assert_allclose(client_diagnostics[b'cid1']['clipped_delta_l2_norm'],
                      0.45000005)
def test_mime(self):
  """Checks Mime init/apply on a scalar model.

  Expected values are regression targets for this fixed setup.
  """
  base_optimizer = optimizers.sgd(learning_rate=1.0)
  train_batch_hparams = client_datasets.ShuffleRepeatBatchHParams(
      batch_size=2, num_epochs=1, seed=0)
  grad_batch_hparams = client_datasets.PaddedBatchHParams(batch_size=2)
  server_learning_rate = 0.2
  algorithm = mime.mime(per_example_loss, base_optimizer, train_batch_hparams,
                        grad_batch_hparams, server_learning_rate)
  with self.subTest('init'):
    state = algorithm.init({'w': jnp.array(4.)})
    npt.assert_equal(state.params, {'w': jnp.array(4.)})
    self.assertLen(state.opt_state, 2)
  with self.subTest('apply'):
    clients = [
        (b'cid0',
         client_datasets.ClientDataset({'x': jnp.array([2., 4., 6.])}),
         jax.random.PRNGKey(0)),
        (b'cid1',
         client_datasets.ClientDataset({'x': jnp.array([8., 10.])}),
         jax.random.PRNGKey(1)),
    ]
    state, client_diagnostics = algorithm.apply(state, clients)
    npt.assert_allclose(state.params['w'], 2.08)
    npt.assert_allclose(client_diagnostics[b'cid0']['delta_l2_norm'], 12.)
    npt.assert_allclose(client_diagnostics[b'cid1']['delta_l2_norm'], 6.)
def test_different_features(self):
  """Datasets with mismatched feature keys cannot be batched together."""
  mismatched = [
      client_datasets.ClientDataset({'x': np.arange(10)}),
      client_datasets.ClientDataset({'y': np.arange(10, 11)}),
  ]
  with self.assertRaisesRegex(
      ValueError, 'client_datasets should have identical features'):
    list(
        client_datasets.padded_batch_client_datasets(mismatched,
                                                     batch_size=4))
def test_different_preprocessors(self):
  """Datasets must share the same Preprocessor object to be batched together."""
  # Two equal-but-distinct preprocessor instances are rejected.
  mismatched = [
      client_datasets.ClientDataset({'x': np.arange(10)},
                                    client_datasets.BatchPreprocessor()),
      client_datasets.ClientDataset({'x': np.arange(10, 11)},
                                    client_datasets.BatchPreprocessor()),
  ]
  with self.assertRaisesRegex(
      ValueError,
      'client_datasets should have the identical Preprocessor object'):
    list(
        client_datasets.padded_batch_client_datasets(mismatched,
                                                     batch_size=4))
def test_different_features(self):
  """Shuffle-batching datasets with mismatched feature keys raises."""
  mismatched = [
      client_datasets.ClientDataset({'x': np.arange(10)}),
      client_datasets.ClientDataset({'y': np.arange(10, 11)}),
  ]
  with self.assertRaisesRegex(
      ValueError, 'client_datasets should have identical features'):
    list(
        client_datasets.buffered_shuffle_batch_client_datasets(
            mismatched, batch_size=4, buffer_size=16,
            rng=np.random.RandomState(0)))
def test_all_examples(self):
  """all_examples returns raw examples, or preprocessed ones when set."""
  raw_examples = {'a': np.arange(3), 'b': np.arange(6).reshape([3, 2])}
  with self.subTest('no preprocessing'):
    dataset = client_datasets.ClientDataset(raw_examples)
    npt.assert_equal(dataset.all_examples(), raw_examples)
  with self.subTest('with preprocessing'):
    # The preprocessor replaces the features entirely with a single key 'c'.
    preprocessor = client_datasets.BatchPreprocessor(
        [lambda x: {'c': x['a'] + 1}])
    dataset = client_datasets.ClientDataset(raw_examples, preprocessor)
    npt.assert_equal(dataset.all_examples(), {'c': [1, 2, 3]})
def test_different_preprocessors(self):
  """Shuffle-batching requires the identical Preprocessor object."""
  # Equal-but-distinct preprocessor instances are rejected.
  mismatched = [
      client_datasets.ClientDataset({'x': np.arange(10, 20)},
                                    client_datasets.BatchPreprocessor()),
      client_datasets.ClientDataset({'x': np.arange(20, 30)},
                                    client_datasets.BatchPreprocessor()),
  ]
  with self.assertRaisesRegex(
      ValueError,
      'client_datasets should have the identical Preprocessor object'):
    list(
        client_datasets.buffered_shuffle_batch_client_datasets(
            mismatched, batch_size=4, buffer_size=16,
            rng=np.random.RandomState(0)))
def sample(self):
  """Returns a single deterministic fake client and advances the round.

  The client id is `round_num + base`; its single example has x equal to the
  id and y equal to the id's parity. The rng slot in the tuple is None.
  """
  client_id = self._round_num + self._base
  examples = {
      'x': np.array([client_id], dtype=np.int32),
      'y': np.array([client_id], dtype=np.int32) % 2,
  }
  self._round_num += 1
  return [(client_id, client_datasets.ClientDataset(examples), None)]
def test_single_buffer_1(self):
  """A shuffle buffer of size 1 performs no shuffling: order is preserved."""
  dataset = client_datasets.ClientDataset({'x': np.arange(6)})
  batches = list(
      client_datasets.buffered_shuffle_batch_client_datasets(
          [dataset], batch_size=5, buffer_size=1,
          rng=np.random.RandomState(0)))
  self.assertLen(batches, 2)
  first, last = batches
  npt.assert_equal(first, {'x': np.arange(5)})
  npt.assert_equal(last, {'x': [5]})
def test_uniform_get_client_sampler(self):
  """UniformGetClientSampler draws deterministic clients; set_round_num resets.

  Client ids 78 and 56 are what seed=0 at round 3 deterministically yields.
  """
  num_clients = 2
  round_num = 3
  client_to_data_mapping = {i: {'x': np.arange(i)} for i in range(100)}
  fd = in_memory_federated_data.InMemoryFederatedData(client_to_data_mapping)
  client_sampler = client_samplers.UniformGetClientSampler(
      fd, num_clients, seed=0, start_round_num=round_num)
  with self.subTest('sample'):
    client_rngs = jax.random.split(jax.random.PRNGKey(round_num),
                                   num_clients)
    expect = [(78, client_datasets.ClientDataset({'x': np.arange(78)}),
               client_rngs[0]),
              (56, client_datasets.ClientDataset({'x': np.arange(56)}),
               client_rngs[1])]
    self.assert_clients_equal(client_sampler.sample(), expect)
  with self.subTest('set_round_num'):
    # sample() advanced the internal round counter; resetting it must
    # reproduce the exact same draw.
    self.assertNotEqual(client_sampler._round_num, round_num)
    client_sampler.set_round_num(round_num)
    self.assert_clients_equal(client_sampler.sample(), expect)
def __next__(
    self) -> Tuple[federated_data.ClientId, client_datasets.ClientDataset]:
  """Advances the client-id cursor and loads that client's examples.

  Raises:
    StopIteration: when the client-id cursor is exhausted.
  """
  row = self._client_ids_cursor.fetchone()
  if row is None:
    raise StopIteration
  client_id = row[0]
  # rowid ordering keeps examples in their original insertion order.
  examples_cursor = self._connection.execute(
      'SELECT serialized_example_proto FROM examples WHERE split_name = ? AND client_id = ? ORDER BY rowid;',
      [self._split_name, client_id])
  serialized = [record[0] for record in examples_cursor.fetchall()]
  return client_id, client_datasets.ClientDataset(
      self._parse_examples(serialized))
def test_expectation_step(self):
  """expectation_step averages client deltas within each assigned cluster.

  Cluster 2 has no assigned clients, so its delta must be None.
  """

  def per_example_loss(params, batch, rng):
    # rng must be threaded through even though the loss ignores it.
    self.assertIsNotNone(rng)
    return jnp.square(params - batch['x'])

  trainer = hyp_cluster.ClientDeltaTrainer(models.grad(per_example_loss),
                                           optimizers.sgd(0.5))
  batch_hparams = client_datasets.ShuffleRepeatBatchHParams(batch_size=1,
                                                            num_epochs=5)
  cluster_params = [jnp.array(1.), jnp.array(-1.), jnp.array(3.14)]
  client_cluster_ids = {b'0': 0, b'1': 0, b'2': 1, b'3': 1, b'4': 0}
  # RNGs are not actually used.
  clients = [
      (b'0', client_datasets.ClientDataset({'x': np.array([1.1])}),
       jax.random.PRNGKey(0)),
      (b'1', client_datasets.ClientDataset({'x': np.array([0.9, 0.9])}),
       jax.random.PRNGKey(1)),
      (b'2', client_datasets.ClientDataset({'x': np.array([-1.1])}),
       jax.random.PRNGKey(2)),
      (b'3',
       client_datasets.ClientDataset({'x': np.array([-0.9, -0.9, -0.9])}),
       jax.random.PRNGKey(3)),
      (b'4', client_datasets.ClientDataset({'x': np.array([-0.1])}),
       jax.random.PRNGKey(4)),
  ]
  cluster_delta_params = hyp_cluster.expectation_step(
      trainer=trainer,
      cluster_params=cluster_params,
      client_cluster_ids=client_cluster_ids,
      clients=clients,
      batch_hparams=batch_hparams)
  self.assertIsInstance(cluster_delta_params, list)
  self.assertLen(cluster_delta_params, 3)
  # Cluster 0: example-count-weighted average over clients b'0', b'1', b'4'.
  npt.assert_allclose(cluster_delta_params[0], (-0.1 + 0.1 * 2 + 1.1) / 4)
  # Cluster 1: weighted average over clients b'2' and b'3'.
  npt.assert_allclose(cluster_delta_params[1], (0.1 - 0.1 * 3) / 4, rtol=1e-6)
  self.assertIsNone(cluster_delta_params[2])
def test_single_default_buckets(self):
  """Without bucketing, the final short batch is padded to full batch_size."""
  dataset = client_datasets.ClientDataset({'x': np.arange(6)})
  batches = list(
      client_datasets.padded_batch_client_datasets([dataset], batch_size=5))
  self.assertLen(batches, 2)
  full, padded = batches
  npt.assert_equal(full, {
      'x': np.arange(5),
      '__mask__': [True, True, True, True, True]
  })
  npt.assert_equal(padded, {
      'x': [5, 0, 0, 0, 0],
      '__mask__': [True, False, False, False, False]
  })
def test_single_buffer_4(self):
  """A buffer of size 4 yields this fixed order under RandomState(0)."""
  dataset = client_datasets.ClientDataset({'x': np.arange(8)})
  batches = list(
      client_datasets.buffered_shuffle_batch_client_datasets(
          [dataset], batch_size=6, buffer_size=4,
          rng=np.random.RandomState(0)))
  self.assertLen(batches, 2)
  expected = [{'x': [2, 4, 5, 6, 7, 3]}, {'x': [1, 0]}]
  for got, want in zip(batches, expected):
    npt.assert_equal(got, want)
def test_single_has_buckets(self):
  """With batch-size buckets, the trailing batch pads to a smaller bucket."""
  dataset = client_datasets.ClientDataset({'x': np.arange(8)})
  batches = list(
      client_datasets.padded_batch_client_datasets([dataset],
                                                   batch_size=6,
                                                   num_batch_size_buckets=4))
  self.assertLen(batches, 2)
  full, bucketed = batches
  npt.assert_equal(full, {
      'x': np.arange(6),
      '__mask__': [True, True, True, True, True, True]
  })
  # Remaining 2 examples land in a bucket of size 3, not the full 6.
  npt.assert_equal(bucketed, {
      'x': [6, 7, 0],
      '__mask__': [True, True, False]
  })
def bench_client_dataset(preprocess, mode, batch_size=128, num_steps=100):
  """Benchmarks ClientDataset batching over the fake MNIST examples.

  Args:
    preprocess: Whether to append the preprocessing function `f` to the
      preprocessor before batching.
    mode: 'train' uses shuffle_repeat_batch for `num_steps` batches; any
      other value uses padded_batch with 4 batch-size buckets.
    batch_size: Number of examples per batch.
    num_steps: Number of batches to draw in 'train' mode.

  Returns:
    The number of batches actually iterated.
  """
  preprocessor = client_datasets.NoOpBatchPreprocessor
  if preprocess:
    preprocessor = preprocessor.append(f)
  dataset = client_datasets.ClientDataset(FAKE_MNIST, preprocessor)
  if mode == 'train':
    batches = dataset.shuffle_repeat_batch(batch_size=batch_size,
                                           num_steps=num_steps)
  else:
    batches = dataset.padded_batch(batch_size=batch_size,
                                   num_batch_size_buckets=4)
  # Idiomatic count of an iterable instead of a manual counter loop.
  return sum(1 for _ in batches)
def test_padded_batch(self):
  """padded_batch applies the preprocessor and supports hparams overrides.

  The preprocessor doubles feature 'a'; padding entries are zeros flagged
  False in '__mask__'.
  """
  d = client_datasets.ClientDataset(
      {
          'a': np.arange(5),
          'b': np.arange(10).reshape([5, 2])
      },
      client_datasets.BatchPreprocessor([lambda x: {
          **x, 'a': 2 * x['a']
      }]))
  with self.subTest('1 bucket, kwargs'):
    view = d.padded_batch(batch_size=3)
    # `view` should be repeatedly iterable.
    for _ in range(2):
      batches = list(view)
      self.assertLen(batches, 2)
      npt.assert_equal(
          batches[0], {
              'a': [0, 2, 4],
              'b': [[0, 1], [2, 3], [4, 5]],
              '__mask__': [True, True, True],
          })
      npt.assert_equal(
          batches[1], {
              'a': [6, 8, 0],
              'b': [[6, 7], [8, 9], [0, 0]],
              '__mask__': [True, True, False]
          })
  with self.subTest('2 buckets, kwargs override'):
    # Explicit kwargs take precedence over the supplied hparams object.
    view = d.padded_batch(
        client_datasets.PaddedBatchHParams(batch_size=4),
        num_batch_size_buckets=2)
    # `view` should be repeatedly iterable.
    for _ in range(2):
      batches = list(view)
      self.assertLen(batches, 2)
      npt.assert_equal(
          batches[0], {
              'a': [0, 2, 4, 6],
              'b': [[0, 1], [2, 3], [4, 5], [6, 7]],
              '__mask__': [True, True, True, True],
          })
      npt.assert_equal(batches[1], {
          'a': [8, 0],
          'b': [[8, 9], [0, 0]],
          '__mask__': [True, False]
      })
def test_slice(self):
  """Slicing a ClientDataset keeps its preprocessor (which doubles 'a')."""
  d = client_datasets.ClientDataset(
      {
          'a': np.arange(5),
          'b': np.arange(10).reshape([5, 2])
      },
      client_datasets.BatchPreprocessor([lambda x: {**x, 'a': 2 * x['a']}]))
  with self.subTest('slice [:3]'):
    batch = next(iter(d[:3].batch(batch_size=3)))
    npt.assert_equal(batch, {'a': [0, 2, 4], 'b': [[0, 1], [2, 3], [4, 5]]})
  with self.subTest('slice [-3:]'):
    batch = next(iter(d[-3:].batch(batch_size=3)))
    npt.assert_equal(batch, {'a': [4, 6, 8], 'b': [[4, 5], [6, 7], [8, 9]]})
def test_preprocessor(self):
  """The dataset preprocessor runs before padded batching (here, x + 1)."""
  preprocessor = client_datasets.BatchPreprocessor(
      [lambda x: {'x': x['x'] + 1}])
  dataset = client_datasets.ClientDataset({'x': np.arange(6)}, preprocessor)
  batches = list(
      client_datasets.padded_batch_client_datasets([dataset], batch_size=5))
  self.assertLen(batches, 2)
  npt.assert_equal(batches[0], {
      'x': np.arange(5) + 1,
      '__mask__': [True, True, True, True, True]
  })
  npt.assert_equal(batches[1], {
      'x': [6, 0, 0, 0, 0],
      '__mask__': [True, False, False, False, False]
  })
def test_preprocessor(self):
  """The preprocessor (x + 1) runs before buffered shuffle batching."""
  preprocessor = client_datasets.BatchPreprocessor(
      [lambda x: {'x': x['x'] + 1}])
  dataset = client_datasets.ClientDataset({'x': np.arange(6)}, preprocessor)
  batches = list(
      client_datasets.buffered_shuffle_batch_client_datasets(
          [dataset], batch_size=5, buffer_size=16,
          rng=np.random.RandomState(0)))
  self.assertLen(batches, 2)
  # Order is fixed by RandomState(0); values are the inputs shifted by one.
  npt.assert_equal(batches[0], {'x': [6, 3, 2, 4, 1]})
  npt.assert_equal(batches[1], {'x': [5]})
def test_batch(self):
  """batch() honors drop_remainder from kwargs or hparams and preprocesses.

  The preprocessor doubles feature 'a'. Each subtest also checks the view
  is repeatedly iterable.
  """
  d = client_datasets.ClientDataset(
      {
          'a': np.arange(5),
          'b': np.arange(10).reshape([5, 2])
      },
      client_datasets.BatchPreprocessor([lambda x: {
          **x, 'a': 2 * x['a']
      }]))
  with self.subTest('keep remainder, kwargs'):
    view = d.batch(batch_size=3)
    # `view` should be repeatedly iterable.
    for _ in range(2):
      batches = list(view)
      self.assertLen(batches, 2)
      npt.assert_equal(batches[0], {
          'a': [0, 2, 4],
          'b': [[0, 1], [2, 3], [4, 5]]
      })
      npt.assert_equal(batches[1], {'a': [6, 8], 'b': [[6, 7], [8, 9]]})
  with self.subTest('drop remainder, hparams'):
    view = d.batch(
        client_datasets.BatchHParams(batch_size=3, drop_remainder=True))
    # `view` should be repeatedly iterable.
    for _ in range(2):
      batches = list(view)
      self.assertLen(batches, 1)
      npt.assert_equal(batches[0], {
          'a': [0, 2, 4],
          'b': [[0, 1], [2, 3], [4, 5]]
      })
  with self.subTest('no op drop remainder, hparams and kwargs'):
    # batch_size 5 divides the dataset evenly, so drop_remainder is a no-op.
    view = d.batch(
        client_datasets.BatchHParams(batch_size=5), drop_remainder=True)
    # `view` should be repeatedly iterable.
    for _ in range(2):
      batches = list(view)
      self.assertLen(batches, 1)
      npt.assert_equal(batches[0], {
          'a': [0, 2, 4, 6, 8],
          'b': [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
      })
def test_agnostic_federated_averaging(self):
  """Checks agnostic fed-avg init/apply plus hyperparameter validation.

  Expected numeric values are regression targets for this fixed setup.
  """
  algorithm = agnostic_fed_avg.agnostic_federated_averaging(
      per_example_loss=per_example_loss,
      client_optimizer=optimizers.sgd(learning_rate=1.0),
      server_optimizer=optimizers.sgd(learning_rate=0.1),
      client_batch_hparams=client_datasets.ShuffleRepeatBatchHParams(
          batch_size=3, num_epochs=1, seed=0),
      domain_batch_hparams=client_datasets.PaddedBatchHParams(batch_size=3),
      init_domain_weights=[0.1, 0.2, 0.3, 0.4],
      domain_learning_rate=0.01,
      domain_algorithm='eg',
      domain_window_size=2,
      init_domain_window=[1., 2., 3., 4.])
  with self.subTest('init'):
    state = algorithm.init({'w': jnp.array(4.)})
    npt.assert_equal(state.params, {'w': jnp.array(4.)})
    self.assertLen(state.opt_state, 2)
    npt.assert_allclose(state.domain_weights, [0.1, 0.2, 0.3, 0.4])
    # The window starts as domain_window_size copies of init_domain_window.
    npt.assert_allclose(state.domain_window,
                        [[1., 2., 3., 4.], [1., 2., 3., 4.]])
  with self.subTest('apply'):
    clients = [
        (b'cid0',
         client_datasets.ClientDataset({
             'x': jnp.array([1., 2., 4., 3., 6., 1.]),
             'domain_id': jnp.array([1, 0, 0, 0, 2, 2])
         }), jax.random.PRNGKey(0)),
        (b'cid1',
         client_datasets.ClientDataset({
             'x': jnp.array([8., 10., 5.]),
             'domain_id': jnp.array([1, 3, 1])
         }), jax.random.PRNGKey(1)),
    ]
    next_state, client_diagnostics = algorithm.apply(state, clients)
    npt.assert_allclose(next_state.params['w'], 3.5555556)
    npt.assert_allclose(
        next_state.domain_weights,
        [0.08702461, 0.18604803, 0.2663479, 0.46057943])
    # The window shifts: the new entry holds per-domain example counts.
    npt.assert_allclose(next_state.domain_window,
                        [[1., 2., 3., 4.], [3., 3., 2., 1.]])
    npt.assert_allclose(client_diagnostics[b'cid0']['delta_l2_norm'],
                        2.8333335)
    npt.assert_allclose(client_diagnostics[b'cid1']['delta_l2_norm'],
                        7.666667)
  with self.subTest('invalid init_domain_weights'):
    with self.assertRaisesRegex(
        ValueError, 'init_domain_weights must sum to approximately 1.'):
      agnostic_fed_avg.agnostic_federated_averaging(
          per_example_loss=per_example_loss,
          client_optimizer=optimizers.sgd(learning_rate=1.0),
          server_optimizer=optimizers.sgd(learning_rate=1.0),
          client_batch_hparams=client_datasets.ShuffleRepeatBatchHParams(
              batch_size=3),
          domain_batch_hparams=client_datasets.PaddedBatchHParams(
              batch_size=3),
          init_domain_weights=[50., 0., 0., 0.],
          domain_learning_rate=0.5)
  with self.subTest('unequal lengths'):
    with self.assertRaisesRegex(
        ValueError,
        'init_domain_weights and init_domain_window must be equal lengths.'):
      agnostic_fed_avg.agnostic_federated_averaging(
          per_example_loss=per_example_loss,
          client_optimizer=optimizers.sgd(learning_rate=1.0),
          server_optimizer=optimizers.sgd(learning_rate=1.0),
          client_batch_hparams=client_datasets.ShuffleRepeatBatchHParams(
              batch_size=3),
          domain_batch_hparams=client_datasets.PaddedBatchHParams(
              batch_size=3),
          init_domain_weights=[0.1, 0.2, 0.3, 0.4],
          domain_learning_rate=0.5,
          init_domain_window=[1, 2])
def test_hyp_cluster(self):
  """End-to-end HypCluster: empty apply is a no-op; clients split 2 clusters.

  Clients near +1 should pick cluster 0, clients near -1 cluster 1, and the
  cluster params should move toward the clients' data means.
  """
  functions_called = set()

  def per_example_loss(params, batch, rng):
    # rng must be threaded through even though the loss ignores it.
    self.assertIsNotNone(rng)
    functions_called.add('per_example_loss')
    return jnp.square(params - batch['x'])

  def regularizer(params):
    del params
    functions_called.add('regularizer')
    return 0

  client_optimizer = optimizers.sgd(0.5)
  server_optimizer = optimizers.sgd(0.25)
  maximization_batch_hparams = client_datasets.PaddedBatchHParams(
      batch_size=2)
  expectation_batch_hparams = client_datasets.ShuffleRepeatBatchHParams(
      batch_size=1, num_epochs=5)
  algorithm = hyp_cluster.hyp_cluster(
      per_example_loss=per_example_loss,
      client_optimizer=client_optimizer,
      server_optimizer=server_optimizer,
      maximization_batch_hparams=maximization_batch_hparams,
      expectation_batch_hparams=expectation_batch_hparams,
      regularizer=regularizer)
  init_state = algorithm.init([jnp.array(1.), jnp.array(-1.)])
  # Nothing happens with empty data.
  no_op_state, diagnostics = algorithm.apply(init_state, clients=[])
  npt.assert_array_equal(init_state.cluster_params,
                         no_op_state.cluster_params)
  self.assertEmpty(diagnostics)
  # Some actual training. PRNGKeys are not actually used.
  clients = [
      (b'0', client_datasets.ClientDataset({'x': np.array([1.1])}),
       jax.random.PRNGKey(0)),
      (b'1', client_datasets.ClientDataset({'x': np.array([0.9, 0.9])}),
       jax.random.PRNGKey(1)),
      (b'2', client_datasets.ClientDataset({'x': np.array([-1.1])}),
       jax.random.PRNGKey(2)),
      (b'3',
       client_datasets.ClientDataset({'x': np.array([-0.9, -0.9, -0.9])}),
       jax.random.PRNGKey(3)),
  ]
  next_state, diagnostics = algorithm.apply(init_state, clients)
  # Diagnostics report each client's chosen cluster.
  npt.assert_equal(
      diagnostics, {
          b'0': {
              'cluster_id': 0
          },
          b'1': {
              'cluster_id': 0
          },
          b'2': {
              'cluster_id': 1
          },
          b'3': {
              'cluster_id': 1
          },
      })
  cluster_params = next_state.cluster_params
  self.assertIsInstance(cluster_params, list)
  self.assertLen(cluster_params, 2)
  # Server SGD(0.25) applied to the example-weighted average cluster delta.
  npt.assert_allclose(cluster_params[0], [1. - 0.25 * 0.1 / 3])
  npt.assert_allclose(cluster_params[1], [-1. + 0.25 * 0.2 / 4])
  self.assertCountEqual(functions_called,
                        ['per_example_loss', 'regularizer'])
def test_hyp_cluster_evaluator(self):
  """HypClusterEvaluator picks each test client's cluster from train data.

  Train clients select a cluster via train loss; test metrics are then
  computed with that cluster's params, invariant to the eval batch size.
  """
  functions_called = set()

  def apply_for_eval(params, batch):
    functions_called.add('apply_for_eval')
    score = params * batch['x']
    # Two-class logits: [-score, score].
    return jnp.stack([-score, score], axis=-1)

  def apply_for_train(params, batch, rng):
    functions_called.add('apply_for_train')
    self.assertIsNotNone(rng)
    return params * batch['x']

  def train_loss(batch, out):
    functions_called.add('train_loss')
    return jnp.abs(batch['y'] * 2 - 1 - out)

  def regularizer(params):
    # Just to check regularizer is called.
    del params
    functions_called.add('regularizer')
    return 0

  evaluator = hyp_cluster.HypClusterEvaluator(
      models.Model(init=None,
                   apply_for_eval=apply_for_eval,
                   apply_for_train=apply_for_train,
                   train_loss=train_loss,
                   eval_metrics={'accuracy': metrics.Accuracy()}),
      regularizer)
  cluster_params = [jnp.array(1.), jnp.array(-1.)]
  train_clients = [
      # Evaluated using cluster 0.
      (b'0',
       client_datasets.ClientDataset({
           'x': np.array([3., 2., 1.]),
           'y': np.array([1, 1, 0])
       }), jax.random.PRNGKey(0)),
      # Evaluated using cluster 1.
      (b'1',
       client_datasets.ClientDataset({
           'x': np.array([0.9, -0.9, 0.8, -0.8, -0.3]),
           'y': np.array([0, 1, 0, 1, 0])
       }), jax.random.PRNGKey(1)),
  ]
  # Test clients are generated from train_clients by swapping client ids and
  # then flipping labels.
  test_clients = [
      # Evaluated using cluster 0.
      (b'0',
       client_datasets.ClientDataset({
           'x': np.array([0.9, -0.9, 0.8, -0.8, -0.3]),
           'y': np.array([1, 0, 1, 0, 1])
       })),
      # Evaluated using cluster 1.
      (b'1',
       client_datasets.ClientDataset({
           'x': np.array([3., 2., 1.]),
           'y': np.array([0, 0, 1])
       })),
  ]
  for batch_size in [1, 2, 4]:
    with self.subTest(f'batch_size = {batch_size}'):
      batch_hparams = client_datasets.PaddedBatchHParams(
          batch_size=batch_size)
      metric_values = dict(
          evaluator.evaluate_clients(cluster_params=cluster_params,
                                     train_clients=train_clients,
                                     test_clients=test_clients,
                                     batch_hparams=batch_hparams))
      self.assertCountEqual(metric_values, [b'0', b'1'])
      self.assertCountEqual(metric_values[b'0'], ['accuracy'])
      npt.assert_allclose(metric_values[b'0']['accuracy'], 4 / 5)
      self.assertCountEqual(metric_values[b'1'], ['accuracy'])
      npt.assert_allclose(metric_values[b'1']['accuracy'], 2 / 3)
  self.assertCountEqual(
      functions_called,
      ['apply_for_train', 'train_loss', 'apply_for_eval', 'regularizer'])
def test_maximization_step(self):
  """Maximization step assigns each client to its lowest-loss cluster.

  The per-example loss is L1 distance from a center whose sign is flipped
  by rng; seeds are chosen so the flip pattern per cluster is known.
  """

  # L1 distance from centers.
  def per_example_loss(params, batch, rng):
    # Randomly flip the center to test rng behavior.
    sign = jax.random.bernoulli(rng) * 2 - 1
    return jnp.abs(params * sign - batch['x'])

  def regularizer(params):
    return 0.01 * jnp.sum(jnp.abs(params))

  evaluator = models.AverageLossEvaluator(per_example_loss, regularizer)
  cluster_params = [jnp.array(0.), jnp.array(-1.), jnp.array(2.)]
  # Batch size is chosen so that we run 1 or 2 batches.
  batch_hparams = client_datasets.PaddedBatchHParams(batch_size=2)
  # Special seeds:
  # - No flip in first 2 steps for all 3 clusters: 0;
  # - Flip all in first 2 steps for all 3 clusters: 16;
  # - No flip then flip all for all 3 clusters: 68;
  # - Flips only cluster 1 in first 2 steps: 106.
  clients = [
      # No flip in first 2 steps for all 3 clusters.
      (b'near0', client_datasets.ClientDataset({'x': np.array([0.1])}),
       jax.random.PRNGKey(0)),
      # Flip all in first 2 steps for all 3 clusters.
      (b'near-1',
       client_datasets.ClientDataset({'x': np.array([0.9, 1.1, 1.3])}),
       jax.random.PRNGKey(16)),
      # No flip then flip all for all 3 clusters.
      (b'near2',
       client_datasets.ClientDataset({'x': np.array([1.9, 2.1, -2.1])}),
       jax.random.PRNGKey(68)),
      # Flips only cluster 1 in first 2 steps.
      (b'near1',
       client_datasets.ClientDataset({'x': np.array([0.9, 1.1, 1.3])}),
       jax.random.PRNGKey(106)),
  ]
  cluster_losses = hyp_cluster._cluster_losses(
      evaluator=evaluator,
      cluster_params=cluster_params,
      clients=clients,
      batch_hparams=batch_hparams)
  self.assertCountEqual(cluster_losses,
                        [b'near0', b'near-1', b'near2', b'near1'])
  # Each client's loss vector has one entry per cluster; the 0.01/0.02 terms
  # come from the L1 regularizer on the cluster params.
  npt.assert_allclose(cluster_losses[b'near0'],
                      np.array([0.1, 1.1 + 0.01, 1.9 + 0.02]))
  npt.assert_allclose(cluster_losses[b'near-1'],
                      np.array([1.1, 0.5 / 3 + 0.01, 3.1 + 0.02]))
  npt.assert_allclose(cluster_losses[b'near2'],
                      np.array([6.1 / 3, 9.1 / 3 + 0.01, 0.1 + 0.02]),
                      rtol=1e-6)
  npt.assert_allclose(cluster_losses[b'near1'],
                      np.array([1.1, 0.5 / 3 + 0.01, 0.9 + 0.02]))
  client_cluster_ids = hyp_cluster.maximization_step(
      evaluator=evaluator,
      cluster_params=cluster_params,
      clients=clients,
      batch_hparams=batch_hparams)
  # argmin over each client's per-cluster losses above.
  self.assertDictEqual(client_cluster_ids, {
      b'near0': 0,
      b'near-1': 1,
      b'near2': 2,
      b'near1': 1
  })
def test_kmeans_init(self):
  """ModelKMeansInitializer picks cluster centers via k-means++ style init.

  Each client has one example, so a trained center lands on that client's x;
  the 'bias' feature only shifts the loss used to pick subsequent centers.
  Which client seeds the first center depends on the rng.
  """
  functions_called = set()

  def init(rng):
    functions_called.add('init')
    return jax.random.uniform(rng)

  def apply_for_train(params, batch, rng):
    functions_called.add('apply_for_train')
    self.assertIsNotNone(rng)
    return params - batch['x']

  def train_loss(batch, out):
    functions_called.add('train_loss')
    return jnp.square(out) + batch['bias']

  def regularizer(params):
    del params
    functions_called.add('regularizer')
    return 0

  initializer = hyp_cluster.ModelKMeansInitializer(
      models.Model(init=init,
                   apply_for_train=apply_for_train,
                   apply_for_eval=None,
                   train_loss=train_loss,
                   eval_metrics={}), optimizers.sgd(0.5), regularizer)
  # Each client has 1 example, so it's very easy to reach minimal loss, at
  # which point the loss entirely depends on bias.
  clients = [
      (b'0',
       client_datasets.ClientDataset({
           'x': np.array([1.01]),
           'bias': np.array([-2.])
       }), jax.random.PRNGKey(1)),
      (b'1',
       client_datasets.ClientDataset({
           'x': np.array([3.02]),
           'bias': np.array([-1.])
       }), jax.random.PRNGKey(2)),
      (b'2',
       client_datasets.ClientDataset({
           'x': np.array([3.03]),
           'bias': np.array([1.])
       }), jax.random.PRNGKey(3)),
      (b'3',
       client_datasets.ClientDataset({
           'x': np.array([1.04]),
           'bias': np.array([2.])
       }), jax.random.PRNGKey(3)),
  ]
  train_batch_hparams = client_datasets.ShuffleRepeatBatchHParams(
      batch_size=1, num_epochs=5)
  eval_batch_hparams = client_datasets.PaddedBatchHParams(batch_size=2)
  # Using a rng that leads to b'0' being the initial center.
  cluster_params = initializer.cluster_params(
      num_clusters=3,
      rng=jax.random.PRNGKey(0),
      clients=clients,
      train_batch_hparams=train_batch_hparams,
      eval_batch_hparams=eval_batch_hparams)
  self.assertIsInstance(cluster_params, list)
  self.assertLen(cluster_params, 3)
  npt.assert_allclose(cluster_params, [1.01, 3.03, 1.04])
  self.assertCountEqual(
      functions_called,
      ['init', 'apply_for_train', 'train_loss', 'regularizer'])
  # Using a rng that leads to b'2' being the initial center.
  cluster_params = initializer.cluster_params(
      num_clusters=3,
      rng=jax.random.PRNGKey(1),
      clients=clients,
      train_batch_hparams=train_batch_hparams,
      eval_batch_hparams=eval_batch_hparams)
  self.assertIsInstance(cluster_params, list)
  self.assertLen(cluster_params, 3)
  npt.assert_allclose(cluster_params, [3.03, 1.04, 1.04])