def test_mime(self):
  """Checks Mime end-to-end: server state init and one round of apply."""
  optimizer = optimizers.sgd(learning_rate=1.0)
  train_hparams = client_datasets.ShuffleRepeatBatchHParams(
      batch_size=2, num_epochs=1, seed=0)
  grad_hparams = client_datasets.PaddedBatchHParams(batch_size=2)
  algorithm = mime.mime(per_example_loss, optimizer, train_hparams,
                        grad_hparams, 0.2)

  with self.subTest('init'):
    state = algorithm.init({'w': jnp.array(4.)})
    npt.assert_equal(state.params, {'w': jnp.array(4.)})
    # Mime tracks both base optimizer state and the full-batch gradient.
    self.assertLen(state.opt_state, 2)

  with self.subTest('apply'):
    clients = [
        (b'cid0',
         client_datasets.ClientDataset({'x': jnp.array([2., 4., 6.])}),
         jax.random.PRNGKey(0)),
        (b'cid1',
         client_datasets.ClientDataset({'x': jnp.array([8., 10.])}),
         jax.random.PRNGKey(1)),
    ]
    state, client_diagnostics = algorithm.apply(state, clients)
    npt.assert_allclose(state.params['w'], 2.08)
    npt.assert_allclose(client_diagnostics[b'cid0']['delta_l2_norm'], 12.)
    npt.assert_allclose(client_diagnostics[b'cid1']['delta_l2_norm'], 6.)
def test_federated_averaging(self):
  """Checks FedAvg init and one round of training over two clients."""
  client_opt = optimizers.sgd(learning_rate=1.0)
  server_opt = optimizers.sgd(learning_rate=1.0)
  batch_hparams = client_datasets.ShuffleRepeatBatchHParams(
      batch_size=2, num_epochs=1, seed=0)
  algorithm = fed_avg.federated_averaging(grad_fn, client_opt, server_opt,
                                          batch_hparams)

  with self.subTest('init'):
    state = algorithm.init({'w': jnp.array([0., 2., 4.])})
    npt.assert_array_equal(state.params['w'], [0., 2., 4.])
    self.assertLen(state.opt_state, 2)

  with self.subTest('apply'):
    clients = [
        (b'cid0',
         client_datasets.ClientDataset({'x': jnp.array([2., 4., 6.])}),
         jax.random.PRNGKey(0)),
        (b'cid1',
         client_datasets.ClientDataset({'x': jnp.array([8., 10.])}),
         jax.random.PRNGKey(1)),
    ]
    state, client_diagnostics = algorithm.apply(state, clients)
    npt.assert_allclose(state.params['w'], [0., 1.5655555, 3.131111])
    npt.assert_allclose(client_diagnostics[b'cid0']['delta_l2_norm'],
                        1.4534444262)
    npt.assert_allclose(client_diagnostics[b'cid1']['delta_l2_norm'],
                        0.2484521282)
def test_client_delta_clip_norm(self):
  """Verifies MimeLite clips client deltas to the configured L2 norm."""
  optimizer = optimizers.sgd(learning_rate=1.0)
  train_hparams = client_datasets.ShuffleRepeatBatchHParams(
      batch_size=2, num_epochs=1, seed=0)
  grad_hparams = client_datasets.PaddedBatchHParams(batch_size=2)
  algorithm = mime_lite.mime_lite(
      per_example_loss,
      optimizer,
      train_hparams,
      grad_hparams,
      0.2,  # server_learning_rate
      client_delta_clip_norm=0.5)

  clients = [
      (b'cid0',
       client_datasets.ClientDataset({'x': jnp.array([0.2, 0.4, 0.6])}),
       jax.random.PRNGKey(0)),
      (b'cid1',
       client_datasets.ClientDataset({'x': jnp.array([0.8, 0.1])}),
       jax.random.PRNGKey(1)),
  ]
  state = algorithm.init({'w': jnp.array(4.)})
  state, client_diagnostics = algorithm.apply(state, clients)
  npt.assert_allclose(state.params['w'], 3.904)
  # cid0's delta exceeds the clip norm and is scaled down to exactly 0.5;
  # cid1's delta is already within the bound.
  npt.assert_allclose(client_diagnostics[b'cid0']['clipped_delta_l2_norm'], 0.5)
  npt.assert_allclose(client_diagnostics[b'cid1']['clipped_delta_l2_norm'],
                      0.45000005)
def test_fed_prox(self):
  """Checks FedProx init and one round of apply with a proximal term."""
  client_opt = optimizers.sgd(learning_rate=1.0)
  server_opt = optimizers.sgd(learning_rate=1.0)
  batch_hparams = client_datasets.ShuffleRepeatBatchHParams(
      batch_size=2, num_epochs=1, seed=0)
  algorithm = fed_prox.fed_prox(per_example_loss, client_opt, server_opt,
                                batch_hparams, proximal_weight=0.01)

  with self.subTest('init'):
    state = algorithm.init({'w': jnp.array(4.)})
    npt.assert_array_equal(state.params['w'], 4.)
    self.assertLen(state.opt_state, 2)

  with self.subTest('apply'):
    clients = [
        (b'cid0',
         client_datasets.ClientDataset({'x': jnp.array([2., 4., 6.])}),
         jax.random.PRNGKey(0)),
        (b'cid1',
         client_datasets.ClientDataset({'x': jnp.array([8., 10.])}),
         jax.random.PRNGKey(1)),
    ]
    state, client_diagnostics = algorithm.apply(state, clients)
    npt.assert_allclose(state.params['w'], -3.77)
    npt.assert_allclose(client_diagnostics[b'cid0']['delta_l2_norm'], 6.95)
    npt.assert_allclose(client_diagnostics[b'cid1']['delta_l2_norm'], 9.)
def __init__(self, name: Optional[str] = None, default_batch_size: int = 128):
  """Registers integer flags for ShuffleRepeatBatchHParams fields.

  Args:
    name: Optional flag-name prefix passed to the base class.
    default_batch_size: Default value for the batch_size flag.
  """
  super().__init__(name)
  # Pull library defaults for every field except batch_size, which has no
  # meaningful default here (placeholder -1 is never exposed).
  hparam_defaults = client_datasets.ShuffleRepeatBatchHParams(batch_size=-1)
  # TODO(wuke): Support other fields.
  self._integer('batch_size', default_batch_size, 'Batch size')
  self._integer('num_epochs', hparam_defaults.num_epochs, 'Number of epochs')
  self._integer('num_steps', hparam_defaults.num_steps, 'Number of steps')
def test_get(self):
  """Hparam flag containers should reflect defaults and flag overrides."""
  with self.subTest('default'):
    self.assertEqual(
        self.SHUFFLE_BATCH_REPEAT.get(),
        client_datasets.ShuffleRepeatBatchHParams(batch_size=128))
    self.assertEqual(self.PADDED_BATCH.get(),
                     client_datasets.PaddedBatchHParams(batch_size=128))
    self.assertEqual(self.BATCH.get(),
                     client_datasets.BatchHParams(batch_size=128))

  with self.subTest('custom'):
    overrides = dict(
        shuffle_batch_repeat_batch_size=12,
        shuffle_batch_repeat_num_epochs=21,
        padded_batch_size=34,
        batch_size=56)
    with flagsaver.flagsaver(**overrides):
      self.assertEqual(
          self.SHUFFLE_BATCH_REPEAT.get(),
          client_datasets.ShuffleRepeatBatchHParams(
              batch_size=12, num_epochs=21))
      self.assertEqual(self.PADDED_BATCH.get(),
                       client_datasets.PaddedBatchHParams(batch_size=34))
      self.assertEqual(self.BATCH.get(),
                       client_datasets.BatchHParams(batch_size=56))
def test_expectation_step(self):
  """Averages client deltas per cluster; clusters with no clients get None."""

  def per_example_loss(params, batch, rng):
    self.assertIsNotNone(rng)
    return jnp.square(params - batch['x'])

  delta_trainer = hyp_cluster.ClientDeltaTrainer(
      models.grad(per_example_loss), optimizers.sgd(0.5))
  batch_hparams = client_datasets.ShuffleRepeatBatchHParams(
      batch_size=1, num_epochs=5)
  cluster_params = [jnp.array(1.), jnp.array(-1.), jnp.array(3.14)]
  client_cluster_ids = {b'0': 0, b'1': 0, b'2': 1, b'3': 1, b'4': 0}
  # RNGs are not actually used.
  clients = [
      (b'0', client_datasets.ClientDataset({'x': np.array([1.1])}),
       jax.random.PRNGKey(0)),
      (b'1', client_datasets.ClientDataset({'x': np.array([0.9, 0.9])}),
       jax.random.PRNGKey(1)),
      (b'2', client_datasets.ClientDataset({'x': np.array([-1.1])}),
       jax.random.PRNGKey(2)),
      (b'3', client_datasets.ClientDataset({'x': np.array([-0.9, -0.9, -0.9])}),
       jax.random.PRNGKey(3)),
      (b'4', client_datasets.ClientDataset({'x': np.array([-0.1])}),
       jax.random.PRNGKey(4)),
  ]

  cluster_delta_params = hyp_cluster.expectation_step(
      trainer=delta_trainer,
      cluster_params=cluster_params,
      client_cluster_ids=client_cluster_ids,
      clients=clients,
      batch_hparams=batch_hparams)

  self.assertIsInstance(cluster_delta_params, list)
  self.assertLen(cluster_delta_params, 3)
  # Cluster 0 averages over 4 examples from clients 0, 1, and 4.
  npt.assert_allclose(cluster_delta_params[0], (-0.1 + 0.1 * 2 + 1.1) / 4)
  # Cluster 1 averages over 4 examples from clients 2 and 3.
  npt.assert_allclose(cluster_delta_params[1], (0.1 - 0.1 * 3) / 4, rtol=1e-6)
  # Cluster 2 has no assigned clients.
  self.assertIsNone(cluster_delta_params[2])
def test_agnostic_federated_averaging(self):
  """Covers init, one round of apply, and constructor validation errors."""
  algorithm = agnostic_fed_avg.agnostic_federated_averaging(
      per_example_loss=per_example_loss,
      client_optimizer=optimizers.sgd(learning_rate=1.0),
      server_optimizer=optimizers.sgd(learning_rate=0.1),
      client_batch_hparams=client_datasets.ShuffleRepeatBatchHParams(
          batch_size=3, num_epochs=1, seed=0),
      domain_batch_hparams=client_datasets.PaddedBatchHParams(batch_size=3),
      init_domain_weights=[0.1, 0.2, 0.3, 0.4],
      domain_learning_rate=0.01,
      domain_algorithm='eg',
      domain_window_size=2,
      init_domain_window=[1., 2., 3., 4.])

  with self.subTest('init'):
    state = algorithm.init({'w': jnp.array(4.)})
    npt.assert_equal(state.params, {'w': jnp.array(4.)})
    self.assertLen(state.opt_state, 2)
    npt.assert_allclose(state.domain_weights, [0.1, 0.2, 0.3, 0.4])
    # The window starts with the initial counts repeated domain_window_size
    # times.
    npt.assert_allclose(state.domain_window,
                        [[1., 2., 3., 4.], [1., 2., 3., 4.]])

  with self.subTest('apply'):
    clients = [
        (b'cid0',
         client_datasets.ClientDataset({
             'x': jnp.array([1., 2., 4., 3., 6., 1.]),
             'domain_id': jnp.array([1, 0, 0, 0, 2, 2])
         }), jax.random.PRNGKey(0)),
        (b'cid1',
         client_datasets.ClientDataset({
             'x': jnp.array([8., 10., 5.]),
             'domain_id': jnp.array([1, 3, 1])
         }), jax.random.PRNGKey(1)),
    ]
    next_state, client_diagnostics = algorithm.apply(state, clients)
    npt.assert_allclose(next_state.params['w'], 3.5555556)
    npt.assert_allclose(next_state.domain_weights,
                        [0.08702461, 0.18604803, 0.2663479, 0.46057943])
    # Newest per-domain example counts displace the oldest window entry.
    npt.assert_allclose(next_state.domain_window,
                        [[1., 2., 3., 4.], [3., 3., 2., 1.]])
    npt.assert_allclose(client_diagnostics[b'cid0']['delta_l2_norm'],
                        2.8333335)
    npt.assert_allclose(client_diagnostics[b'cid1']['delta_l2_norm'],
                        7.666667)

  with self.subTest('invalid init_domain_weights'):
    with self.assertRaisesRegex(
        ValueError, 'init_domain_weights must sum to approximately 1.'):
      agnostic_fed_avg.agnostic_federated_averaging(
          per_example_loss=per_example_loss,
          client_optimizer=optimizers.sgd(learning_rate=1.0),
          server_optimizer=optimizers.sgd(learning_rate=1.0),
          client_batch_hparams=client_datasets.ShuffleRepeatBatchHParams(
              batch_size=3),
          domain_batch_hparams=client_datasets.PaddedBatchHParams(
              batch_size=3),
          init_domain_weights=[50., 0., 0., 0.],
          domain_learning_rate=0.5)

  with self.subTest('unequal lengths'):
    with self.assertRaisesRegex(
        ValueError,
        'init_domain_weights and init_domain_window must be equal lengths.'):
      agnostic_fed_avg.agnostic_federated_averaging(
          per_example_loss=per_example_loss,
          client_optimizer=optimizers.sgd(learning_rate=1.0),
          server_optimizer=optimizers.sgd(learning_rate=1.0),
          client_batch_hparams=client_datasets.ShuffleRepeatBatchHParams(
              batch_size=3),
          domain_batch_hparams=client_datasets.PaddedBatchHParams(
              batch_size=3),
          init_domain_weights=[0.1, 0.2, 0.3, 0.4],
          domain_learning_rate=0.5,
          init_domain_window=[1, 2])
def test_hyp_cluster(self):
  """End-to-end HypCluster: empty round is a no-op, then one real round."""
  functions_called = set()

  def per_example_loss(params, batch, rng):
    self.assertIsNotNone(rng)
    functions_called.add('per_example_loss')
    return jnp.square(params - batch['x'])

  def regularizer(params):
    del params
    functions_called.add('regularizer')
    return 0

  algorithm = hyp_cluster.hyp_cluster(
      per_example_loss=per_example_loss,
      client_optimizer=optimizers.sgd(0.5),
      server_optimizer=optimizers.sgd(0.25),
      maximization_batch_hparams=client_datasets.PaddedBatchHParams(
          batch_size=2),
      expectation_batch_hparams=client_datasets.ShuffleRepeatBatchHParams(
          batch_size=1, num_epochs=5),
      regularizer=regularizer)

  init_state = algorithm.init([jnp.array(1.), jnp.array(-1.)])

  # Nothing happens with empty data.
  no_op_state, diagnostics = algorithm.apply(init_state, clients=[])
  npt.assert_array_equal(init_state.cluster_params,
                         no_op_state.cluster_params)
  self.assertEmpty(diagnostics)

  # Some actual training. PRNGKeys are not actually used.
  clients = [
      (b'0', client_datasets.ClientDataset({'x': np.array([1.1])}),
       jax.random.PRNGKey(0)),
      (b'1', client_datasets.ClientDataset({'x': np.array([0.9, 0.9])}),
       jax.random.PRNGKey(1)),
      (b'2', client_datasets.ClientDataset({'x': np.array([-1.1])}),
       jax.random.PRNGKey(2)),
      (b'3', client_datasets.ClientDataset({'x': np.array([-0.9, -0.9, -0.9])}),
       jax.random.PRNGKey(3)),
  ]
  next_state, diagnostics = algorithm.apply(init_state, clients)

  # Clients near +1 land in cluster 0; clients near -1 land in cluster 1.
  npt.assert_equal(
      diagnostics, {
          b'0': {'cluster_id': 0},
          b'1': {'cluster_id': 0},
          b'2': {'cluster_id': 1},
          b'3': {'cluster_id': 1},
      })
  cluster_params = next_state.cluster_params
  self.assertIsInstance(cluster_params, list)
  self.assertLen(cluster_params, 2)
  npt.assert_allclose(cluster_params[0], [1. - 0.25 * 0.1 / 3])
  npt.assert_allclose(cluster_params[1], [-1. + 0.25 * 0.2 / 4])
  self.assertCountEqual(functions_called,
                        ['per_example_loss', 'regularizer'])
def test_kmeans_init(self):
  """k-means++-style center selection should depend on the seeding RNG."""
  functions_called = set()

  def init(rng):
    functions_called.add('init')
    return jax.random.uniform(rng)

  def apply_for_train(params, batch, rng):
    functions_called.add('apply_for_train')
    self.assertIsNotNone(rng)
    return params - batch['x']

  def train_loss(batch, out):
    functions_called.add('train_loss')
    return jnp.square(out) + batch['bias']

  def regularizer(params):
    del params
    functions_called.add('regularizer')
    return 0

  initializer = hyp_cluster.ModelKMeansInitializer(
      models.Model(
          init=init,
          apply_for_train=apply_for_train,
          apply_for_eval=None,
          train_loss=train_loss,
          eval_metrics={}), optimizers.sgd(0.5), regularizer)

  # Each client has 1 example, so it's very easy to reach minimal loss, at
  # which point the loss entirely depends on bias.
  clients = [
      (b'0',
       client_datasets.ClientDataset({
           'x': np.array([1.01]),
           'bias': np.array([-2.])
       }), jax.random.PRNGKey(1)),
      (b'1',
       client_datasets.ClientDataset({
           'x': np.array([3.02]),
           'bias': np.array([-1.])
       }), jax.random.PRNGKey(2)),
      (b'2',
       client_datasets.ClientDataset({
           'x': np.array([3.03]),
           'bias': np.array([1.])
       }), jax.random.PRNGKey(3)),
      (b'3',
       client_datasets.ClientDataset({
           'x': np.array([1.04]),
           'bias': np.array([2.])
       }), jax.random.PRNGKey(3)),
  ]
  train_hparams = client_datasets.ShuffleRepeatBatchHParams(
      batch_size=1, num_epochs=5)
  eval_hparams = client_datasets.PaddedBatchHParams(batch_size=2)

  # Using a rng that leads to b'0' being the initial center.
  cluster_params = initializer.cluster_params(
      num_clusters=3,
      rng=jax.random.PRNGKey(0),
      clients=clients,
      train_batch_hparams=train_hparams,
      eval_batch_hparams=eval_hparams)
  self.assertIsInstance(cluster_params, list)
  self.assertLen(cluster_params, 3)
  npt.assert_allclose(cluster_params, [1.01, 3.03, 1.04])
  self.assertCountEqual(
      functions_called,
      ['init', 'apply_for_train', 'train_loss', 'regularizer'])

  # Using a rng that leads to b'2' being the initial center.
  cluster_params = initializer.cluster_params(
      num_clusters=3,
      rng=jax.random.PRNGKey(1),
      clients=clients,
      train_batch_hparams=train_hparams,
      eval_batch_hparams=eval_hparams)
  self.assertIsInstance(cluster_params, list)
  self.assertLen(cluster_params, 3)
  npt.assert_allclose(cluster_params, [3.03, 1.04, 1.04])
def get(self):
  """Builds a ShuffleRepeatBatchHParams from the current flag values."""
  field_names = ('batch_size', 'num_epochs', 'num_steps')
  kwargs = {name: self._get_flag(name) for name in field_names}
  return client_datasets.ShuffleRepeatBatchHParams(**kwargs)