def test_fed_prox(self):
  client_optimizer = optimizers.sgd(learning_rate=1.0)
  server_optimizer = optimizers.sgd(learning_rate=1.0)
  client_batch_hparams = client_datasets.ShuffleRepeatBatchHParams(
      batch_size=2, num_epochs=1, seed=0)
  algorithm = fed_prox.fed_prox(per_example_loss, client_optimizer,
                                server_optimizer, client_batch_hparams,
                                proximal_weight=0.01)
  with self.subTest('init'):
    state = algorithm.init({'w': jnp.array(4.)})
    npt.assert_array_equal(state.params['w'], 4.)
    self.assertLen(state.opt_state, 2)
  with self.subTest('apply'):
    clients = [
        (b'cid0',
         client_datasets.ClientDataset({'x': jnp.array([2., 4., 6.])}),
         jax.random.PRNGKey(0)),
        (b'cid1', client_datasets.ClientDataset({'x': jnp.array([8., 10.])}),
         jax.random.PRNGKey(1)),
    ]
    state, client_diagnostics = algorithm.apply(state, clients)
    npt.assert_allclose(state.params['w'], -3.77)
    npt.assert_allclose(client_diagnostics[b'cid0']['delta_l2_norm'], 6.95)
    npt.assert_allclose(client_diagnostics[b'cid1']['delta_l2_norm'], 9.)
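# Several tests in this listing reference a module-level `per_example_loss`
# that is not shown in the excerpt. A minimal sketch consistent with the
# expected values is the linear model below (an assumption, not the verbatim
# original; it relies on the module's existing jnp import):
def per_example_loss(params, batch, rng):
  del rng  # Deterministic loss; rng is part of the fedjax signature.
  return batch['x'] * params['w']
# Sanity check against test_fed_prox: client b'cid1' sees a single batch
# [8., 10.], so its gradient is mean(x) plus the proximal term
# 0.01 * (w - w_server) = 9. + 0., matching delta_l2_norm == 9.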
def test_federated_averaging(self):
  client_optimizer = optimizers.sgd(learning_rate=1.0)
  server_optimizer = optimizers.sgd(learning_rate=1.0)
  client_batch_hparams = client_datasets.ShuffleRepeatBatchHParams(
      batch_size=2, num_epochs=1, seed=0)
  algorithm = fed_avg.federated_averaging(grad_fn, client_optimizer,
                                          server_optimizer,
                                          client_batch_hparams)
  with self.subTest('init'):
    state = algorithm.init({'w': jnp.array([0., 2., 4.])})
    npt.assert_array_equal(state.params['w'], [0., 2., 4.])
    self.assertLen(state.opt_state, 2)
  with self.subTest('apply'):
    clients = [
        (b'cid0',
         client_datasets.ClientDataset({'x': jnp.array([2., 4., 6.])}),
         jax.random.PRNGKey(0)),
        (b'cid1', client_datasets.ClientDataset({'x': jnp.array([8., 10.])}),
         jax.random.PRNGKey(1)),
    ]
    state, client_diagnostics = algorithm.apply(state, clients)
    npt.assert_allclose(state.params['w'], [0., 1.5655555, 3.131111])
    npt.assert_allclose(client_diagnostics[b'cid0']['delta_l2_norm'],
                        1.4534444262)
    npt.assert_allclose(client_diagnostics[b'cid1']['delta_l2_norm'],
                        0.2484521282)
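# Hedged sketch: `grad_fn` in the federated averaging tests is defined
# elsewhere in the module. A gradient function consistent with the expected
# numbers scales the current params by the inverse batch sum (an assumption,
# not the verbatim original):
def grad_fn(params, batch, rng):
  del rng  # Deterministic gradient; rng is part of the fedjax signature.
  return jax.tree_util.tree_map(lambda p: p / jnp.sum(batch['x']), params)
# Sanity check: client b'cid1' sees one batch [8., 10.] (sum 18), so its
# delta is w / 18 = [0., 2/18, 4/18], whose l2 norm is
# sqrt(5) / 9 ~= 0.2484521282, matching the assertion above.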
def test_create_train_for_each_client(self):
  base_optimizer = optimizers.sgd(learning_rate=1.0)
  train_for_each_client = mime.create_train_for_each_client(
      grad_fn, base_optimizer)
  batch_clients = [
      (b'cid0', [{
          'x': jnp.array([6., 4.])
      }, {
          'x': jnp.array([2., 2.])
      }], jax.random.PRNGKey(0)),
      (b'cid1', [{
          'x': jnp.array([10., 8])
      }], jax.random.PRNGKey(1)),
  ]
  server_params = {'w': jnp.array(4.)}
  server_opt_state = base_optimizer.init(server_params)
  server_grads = {'w': jnp.array(6.)}
  shared_input = {
      'params': server_params,
      'opt_state': server_opt_state,
      'control_variate': server_grads
  }
  client_outputs = dict(train_for_each_client(shared_input, batch_clients))
  npt.assert_allclose(client_outputs[b'cid0']['w'], 12.)
  npt.assert_allclose(client_outputs[b'cid1']['w'], 6.)
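# Note on the expected values: in the Mime tests, `grad_fn` appears to be the
# gradient of the linear per-example loss x * w, i.e. mean(batch['x']), which
# does not depend on the params. The Mime client update is
# grad(w_client) - grad(w_server) + control_variate, so the first two terms
# cancel and every local step applies the control variate (6.) directly:
# b'cid0' takes 2 steps (delta 12.), b'cid1' takes 1 step (delta 6.).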
def test_mime(self):
  base_optimizer = optimizers.sgd(learning_rate=1.0)
  train_batch_hparams = client_datasets.ShuffleRepeatBatchHParams(
      batch_size=2, num_epochs=1, seed=0)
  grad_batch_hparams = client_datasets.PaddedBatchHParams(batch_size=2)
  server_learning_rate = 0.2
  algorithm = mime.mime(per_example_loss, base_optimizer,
                        train_batch_hparams, grad_batch_hparams,
                        server_learning_rate)
  with self.subTest('init'):
    state = algorithm.init({'w': jnp.array(4.)})
    npt.assert_equal(state.params, {'w': jnp.array(4.)})
    self.assertLen(state.opt_state, 2)
  with self.subTest('apply'):
    clients = [
        (b'cid0',
         client_datasets.ClientDataset({'x': jnp.array([2., 4., 6.])}),
         jax.random.PRNGKey(0)),
        (b'cid1', client_datasets.ClientDataset({'x': jnp.array([8., 10.])}),
         jax.random.PRNGKey(1)),
    ]
    state, client_diagnostics = algorithm.apply(state, clients)
    npt.assert_allclose(state.params['w'], 2.08)
    npt.assert_allclose(client_diagnostics[b'cid0']['delta_l2_norm'], 12.)
    npt.assert_allclose(client_diagnostics[b'cid1']['delta_l2_norm'], 6.)
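# Worked server update for test_mime, under the same linear-loss assumption:
# the control variate is the example-weighted average of full-batch client
# gradients, (3 * 4. + 2 * 9.) / 5 = 6. The client deltas are 12. and 6.
# (see the note after test_create_train_for_each_client above), so the
# server applies
# w = 4. - 0.2 * (3 * 12. + 2 * 6.) / 5 = 4. - 0.2 * 9.6 = 2.08.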
def test_create_train_for_each_client(self):
  num_domains = 4
  shared_input = {
      'params': {
          'w': jnp.array(4.)
      },
      'alpha': jnp.array([0.1, 0.2, 0.3, 0.4])
  }
  batch_clients = [
      (b'cid0', [{
          'x': jnp.array([1., 2., 4.]),
          'domain_id': jnp.array([1, 0, 0]),
      }, {
          'x': jnp.array([3., 6., 1.]),
          'domain_id': jnp.array([0, 2, 2]),
      }], {
          'rng': jax.random.PRNGKey(0),
          'beta': jnp.array(0.5)
      }),
      (b'cid1', [{
          'x': jnp.array([8., 10., 5.]),
          'domain_id': jnp.array([1, 3, 1]),
      }], {
          'rng': jax.random.PRNGKey(1),
          'beta': jnp.array(0.2)
      }),
  ]
  client_optimizer = optimizers.sgd(learning_rate=1.0)
  func = agnostic_fed_avg.create_train_for_each_client(
      per_example_loss, client_optimizer, num_domains)
  client_delta_params = dict(func(shared_input, batch_clients))
  npt.assert_allclose(client_delta_params[b'cid0']['w'], 6.4)
  npt.assert_allclose(client_delta_params[b'cid1']['w'], 33.)
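# Reading guide: `alpha` holds the server's per-domain mixture weights
# (domains 0..3) shared with every client, while each client's per-client
# input carries its own rng and a `beta` scaling factor. The expected deltas
# come from agnostic federated averaging's domain-weighted client loss; we
# defer to the implementation for the exact weighting rather than re-derive
# it here.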
def test_client_delta_clip_norm(self):
  base_optimizer = optimizers.sgd(learning_rate=1.0)
  train_batch_hparams = client_datasets.ShuffleRepeatBatchHParams(
      batch_size=2, num_epochs=1, seed=0)
  grad_batch_hparams = client_datasets.PaddedBatchHParams(batch_size=2)
  server_learning_rate = 0.2
  algorithm = mime_lite.mime_lite(per_example_loss, base_optimizer,
                                  train_batch_hparams, grad_batch_hparams,
                                  server_learning_rate,
                                  client_delta_clip_norm=0.5)
  clients = [
      (b'cid0',
       client_datasets.ClientDataset({'x': jnp.array([0.2, 0.4, 0.6])}),
       jax.random.PRNGKey(0)),
      (b'cid1', client_datasets.ClientDataset({'x': jnp.array([0.8, 0.1])}),
       jax.random.PRNGKey(1)),
  ]
  state = algorithm.init({'w': jnp.array(4.)})
  state, client_diagnostics = algorithm.apply(state, clients)
  npt.assert_allclose(state.params['w'], 3.904)
  npt.assert_allclose(client_diagnostics[b'cid0']['clipped_delta_l2_norm'],
                      0.5)
  npt.assert_allclose(client_diagnostics[b'cid1']['clipped_delta_l2_norm'],
                      0.45000005)
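# Worked check of the clipping, under the linear-loss assumption: b'cid1'
# takes one step of mean([0.8, 0.1]) = 0.45, already under the clip norm 0.5,
# so it passes through (the 0.45000005 is float32 rounding), while b'cid0''s
# larger delta (0.7 unclipped) is scaled down to norm exactly 0.5. The server
# then applies w = 4. - 0.2 * (3 * 0.5 + 2 * 0.45) / 5 = 4. - 0.096 = 3.904.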
def test_ignore_grads_haiku(self):
  params = hk.data_structures.to_immutable_dict({
      'linear_1': {
          'w': jnp.array([1., 1., 1.])
      },
      'linear_2': {
          'w': jnp.array([2., 2., 2.]),
          'b': jnp.array([3., 3., 3.])
      }
  })
  grads = jax.tree_util.tree_map(lambda _: 0.5, params)
  ignore_optimizer = optimizers.ignore_grads_haiku(
      optimizer=optimizers.sgd(learning_rate=1.0),
      non_trainable_names=[('linear_1', 'w'), ('linear_2', 'b')])
  opt_state = ignore_optimizer.init(params)
  opt_state, updated_params = ignore_optimizer.apply(grads, opt_state, params)
  # The non-trainable entries must pass through unchanged; only
  # ('linear_2', 'w') receives the SGD update 2. - 1.0 * 0.5 = 1.5.
  jax.tree_util.tree_map(
      npt.assert_array_equal, updated_params,
      hk.data_structures.to_immutable_dict({
          'linear_1': {
              'w': jnp.array([1., 1., 1.])
          },
          'linear_2': {
              'w': jnp.array([1.5, 1.5, 1.5]),
              'b': jnp.array([3., 3., 3.])
          }
      }))
def test_create_train_for_each_client(self):
  proximal_weight = 0.01

  def fed_prox_loss(params, server_params, batch, rng):
    example_loss = per_example_loss(params, batch, rng)
    proximal_loss = 0.5 * proximal_weight * tree_util.tree_l2_squared(
        jax.tree_util.tree_map(lambda a, b: a - b, server_params, params))
    return jnp.mean(example_loss + proximal_loss)

  grad_fn = jax.grad(fed_prox_loss)
  client_optimizer = optimizers.sgd(learning_rate=1.0)
  train_for_each_client = fed_prox.create_train_for_each_client(
      grad_fn, client_optimizer)
  batched_clients = [
      (b'cid0', [{
          'x': jnp.array([2., 4., 6.])
      }, {
          'x': jnp.array([8., 10., 12.])
      }], jax.random.PRNGKey(0)),
      (b'cid1', [{
          'x': jnp.array([1., 3., 5.])
      }, {
          'x': jnp.array([7., 9., 11.])
      }], jax.random.PRNGKey(1)),
  ]
  server_params = {'w': jnp.array(4.0)}
  client_outputs = dict(train_for_each_client(server_params, batched_clients))
  npt.assert_allclose(client_outputs[b'cid0']['w'], 13.96)
  npt.assert_allclose(client_outputs[b'cid1']['w'], 11.97)
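# Hand derivation of the expected deltas, under the linear-loss assumption:
# the proximal gradient is proximal_weight * (w - w_server). For b'cid0':
# step 1: grad = mean([2., 4., 6.]) + 0.01 * (4. - 4.) = 4.,  w: 4. -> 0.
# step 2: grad = mean([8., 10., 12.]) + 0.01 * (0. - 4.) = 9.96
# so the total delta is 4. + 9.96 = 13.96. Similarly for b'cid1':
# 3. + (9. + 0.01 * (1. - 4.)) = 3. + 8.97 = 11.97.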
def get(self) -> optimizers.Optimizer:
  """Gets the specified optimizer."""
  optimizer_name = self._get_flag('optimizer')
  learning_rate = self._get_flag('learning_rate')
  if optimizer_name == 'sgd':
    return optimizers.sgd(learning_rate)
  elif optimizer_name == 'momentum':
    return optimizers.sgd(learning_rate, self._get_flag('momentum'))
  elif optimizer_name == 'adam':
    return optimizers.adam(learning_rate, self._get_flag('adam_beta1'),
                           self._get_flag('adam_beta2'),
                           self._get_flag('adam_epsilon'))
  elif optimizer_name == 'rmsprop':
    return optimizers.rmsprop(learning_rate, self._get_flag('rmsprop_decay'),
                              self._get_flag('rmsprop_epsilon'))
  elif optimizer_name == 'adagrad':
    return optimizers.adagrad(learning_rate,
                              eps=self._get_flag('adagrad_epsilon'))
  else:
    raise ValueError(f'Unsupported optimizer {optimizer_name!r} from '
                     f'--{self._prefix}optimizer.')
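# Usage note (flag names are illustrative): the error message shows that
# flag names resolve as self._prefix + name, so with a hypothetical prefix
# of 'server_' this getter would read --server_optimizer and
# --server_learning_rate, plus the per-optimizer flags, e.g. for 'adam':
# --server_adam_beta1, --server_adam_beta2, and --server_adam_epsilon.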
def test_create_train_for_each_client(self):
  client_optimizer = optimizers.sgd(learning_rate=1.0)
  train_for_each_client = fed_avg.create_train_for_each_client(
      grad_fn, client_optimizer)
  batched_clients = [
      (b'cid0', [{
          'x': jnp.array([2., 4., 6.])
      }, {
          'x': jnp.array([8., 10., 12.])
      }], jax.random.PRNGKey(0)),
      (b'cid1', [{
          'x': jnp.array([1., 3., 5.])
      }, {
          'x': jnp.array([7., 9., 11.])
      }], jax.random.PRNGKey(1)),
  ]
  server_params = {'w': jnp.array(4.0)}
  client_outputs = dict(train_for_each_client(server_params, batched_clients))
  npt.assert_allclose(client_outputs[b'cid0']['w'], 0.45555544)
  npt.assert_allclose(client_outputs[b'cid1']['w'], 0.5761316)
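# Hand derivation, using the grad_fn sketch above (each step's gradient is
# w / sum(batch['x']) and the delta is the sum of steps). For b'cid0':
# step 1: 4. / 12. = 0.333333..., w: 4. -> 3.666666...
# step 2: 3.666666... / 30. = 0.122222...
# total delta 0.45555544 in float32. For b'cid1':
# 4. / 9. + (4. - 4. / 9.) / 27. = 0.5761316.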
def test_expectation_step(self):

  def per_example_loss(params, batch, rng):
    self.assertIsNotNone(rng)
    return jnp.square(params - batch['x'])

  trainer = hyp_cluster.ClientDeltaTrainer(models.grad(per_example_loss),
                                           optimizers.sgd(0.5))
  batch_hparams = client_datasets.ShuffleRepeatBatchHParams(batch_size=1,
                                                            num_epochs=5)
  cluster_params = [jnp.array(1.), jnp.array(-1.), jnp.array(3.14)]
  client_cluster_ids = {b'0': 0, b'1': 0, b'2': 1, b'3': 1, b'4': 0}
  # RNGs are not actually used.
  clients = [
      (b'0', client_datasets.ClientDataset({'x': np.array([1.1])}),
       jax.random.PRNGKey(0)),
      (b'1', client_datasets.ClientDataset({'x': np.array([0.9, 0.9])}),
       jax.random.PRNGKey(1)),
      (b'2', client_datasets.ClientDataset({'x': np.array([-1.1])}),
       jax.random.PRNGKey(2)),
      (b'3',
       client_datasets.ClientDataset({'x': np.array([-0.9, -0.9, -0.9])}),
       jax.random.PRNGKey(3)),
      (b'4', client_datasets.ClientDataset({'x': np.array([-0.1])}),
       jax.random.PRNGKey(4)),
  ]
  cluster_delta_params = hyp_cluster.expectation_step(
      trainer=trainer,
      cluster_params=cluster_params,
      client_cluster_ids=client_cluster_ids,
      clients=clients,
      batch_hparams=batch_hparams)
  self.assertIsInstance(cluster_delta_params, list)
  self.assertLen(cluster_delta_params, 3)
  npt.assert_allclose(cluster_delta_params[0], (-0.1 + 0.1 * 2 + 1.1) / 4)
  npt.assert_allclose(cluster_delta_params[1], (0.1 - 0.1 * 3) / 4, rtol=1e-6)
  self.assertIsNone(cluster_delta_params[2])
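# Why these expected values: with loss (params - x)^2, SGD with learning
# rate 0.5 jumps straight to the batch value in one step
# (p - 0.5 * 2 * (p - x) = x) and then stays there, so each client's delta
# is initial_params - x. Cluster 0 (params 1.) averages the deltas of
# clients b'0', b'1', b'4' weighted by their example counts 1, 2, 1, giving
# (-0.1 + 0.1 * 2 + 1.1) / 4; cluster 1 is analogous. Cluster 2 received no
# clients, so its entry is None.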
def test_create_train_for_each_client(self):
  base_optimizer = optimizers.sgd(learning_rate=1.0)
  train_for_each_client = mime_lite.create_train_for_each_client(
      grad_fn, base_optimizer)
  batch_clients = [
      (b'cid0', [{
          'x': jnp.array([0.6, 0.4])
      }, {
          'x': jnp.array([0.2, 0.2])
      }], jax.random.PRNGKey(0)),
      (b'cid1', [{
          'x': jnp.array([0.1, 0.8])
      }], jax.random.PRNGKey(1)),
  ]
  server_params = {'w': jnp.array(4.)}
  server_opt_state = base_optimizer.init(server_params)
  shared_input = {
      'params': server_params,
      'opt_state': server_opt_state,
  }
  client_outputs = dict(train_for_each_client(shared_input, batch_clients))
  npt.assert_allclose(client_outputs[b'cid0']['w'], 0.7)
  npt.assert_allclose(client_outputs[b'cid1']['w'], 0.45000002)
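# Expected values under the linear-loss assumption (grad = mean(batch['x']),
# independent of params): b'cid0' accumulates 0.5 + 0.2 = 0.7 over its two
# batches; b'cid1' takes a single step of mean([0.1, 0.8]) = 0.45 (the
# trailing 0.45000002 is float32 rounding).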
def test_agnostic_federated_averaging(self):
  algorithm = agnostic_fed_avg.agnostic_federated_averaging(
      per_example_loss=per_example_loss,
      client_optimizer=optimizers.sgd(learning_rate=1.0),
      server_optimizer=optimizers.sgd(learning_rate=0.1),
      client_batch_hparams=client_datasets.ShuffleRepeatBatchHParams(
          batch_size=3, num_epochs=1, seed=0),
      domain_batch_hparams=client_datasets.PaddedBatchHParams(batch_size=3),
      init_domain_weights=[0.1, 0.2, 0.3, 0.4],
      domain_learning_rate=0.01,
      domain_algorithm='eg',
      domain_window_size=2,
      init_domain_window=[1., 2., 3., 4.])
  with self.subTest('init'):
    state = algorithm.init({'w': jnp.array(4.)})
    npt.assert_equal(state.params, {'w': jnp.array(4.)})
    self.assertLen(state.opt_state, 2)
    npt.assert_allclose(state.domain_weights, [0.1, 0.2, 0.3, 0.4])
    npt.assert_allclose(state.domain_window,
                        [[1., 2., 3., 4.], [1., 2., 3., 4.]])
  with self.subTest('apply'):
    clients = [
        (b'cid0',
         client_datasets.ClientDataset({
             'x': jnp.array([1., 2., 4., 3., 6., 1.]),
             'domain_id': jnp.array([1, 0, 0, 0, 2, 2])
         }), jax.random.PRNGKey(0)),
        (b'cid1',
         client_datasets.ClientDataset({
             'x': jnp.array([8., 10., 5.]),
             'domain_id': jnp.array([1, 3, 1])
         }), jax.random.PRNGKey(1)),
    ]
    next_state, client_diagnostics = algorithm.apply(state, clients)
    npt.assert_allclose(next_state.params['w'], 3.5555556)
    npt.assert_allclose(next_state.domain_weights,
                        [0.08702461, 0.18604803, 0.2663479, 0.46057943])
    npt.assert_allclose(next_state.domain_window,
                        [[1., 2., 3., 4.], [3., 3., 2., 1.]])
    npt.assert_allclose(client_diagnostics[b'cid0']['delta_l2_norm'],
                        2.8333335)
    npt.assert_allclose(client_diagnostics[b'cid1']['delta_l2_norm'],
                        7.666667)
  with self.subTest('invalid init_domain_weights'):
    with self.assertRaisesRegex(
        ValueError, 'init_domain_weights must sum to approximately 1.'):
      agnostic_fed_avg.agnostic_federated_averaging(
          per_example_loss=per_example_loss,
          client_optimizer=optimizers.sgd(learning_rate=1.0),
          server_optimizer=optimizers.sgd(learning_rate=1.0),
          client_batch_hparams=client_datasets.ShuffleRepeatBatchHParams(
              batch_size=3),
          domain_batch_hparams=client_datasets.PaddedBatchHParams(
              batch_size=3),
          init_domain_weights=[50., 0., 0., 0.],
          domain_learning_rate=0.5)
  with self.subTest('unequal lengths'):
    with self.assertRaisesRegex(
        ValueError,
        'init_domain_weights and init_domain_window must be equal lengths.'):
      agnostic_fed_avg.agnostic_federated_averaging(
          per_example_loss=per_example_loss,
          client_optimizer=optimizers.sgd(learning_rate=1.0),
          server_optimizer=optimizers.sgd(learning_rate=1.0),
          client_batch_hparams=client_datasets.ShuffleRepeatBatchHParams(
              batch_size=3),
          domain_batch_hparams=client_datasets.PaddedBatchHParams(
              batch_size=3),
          init_domain_weights=[0.1, 0.2, 0.3, 0.4],
          domain_learning_rate=0.5,
          init_domain_window=[1, 2])
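# Reading the domain_window assertion: the window is a rolling buffer of
# per-domain example counts over the last domain_window_size rounds. The new
# row [3., 3., 2., 1.] is exactly the number of examples with domain_id
# 0, 1, 2, 3 across both clients in this round. 'eg' selects an
# exponentiated-gradient update for the domain weights, which keeps them
# positive and (as the asserted values do) summing to 1.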
def test_hyp_cluster(self):
  functions_called = set()

  def per_example_loss(params, batch, rng):
    self.assertIsNotNone(rng)
    functions_called.add('per_example_loss')
    return jnp.square(params - batch['x'])

  def regularizer(params):
    del params
    functions_called.add('regularizer')
    return 0

  client_optimizer = optimizers.sgd(0.5)
  server_optimizer = optimizers.sgd(0.25)
  maximization_batch_hparams = client_datasets.PaddedBatchHParams(
      batch_size=2)
  expectation_batch_hparams = client_datasets.ShuffleRepeatBatchHParams(
      batch_size=1, num_epochs=5)
  algorithm = hyp_cluster.hyp_cluster(
      per_example_loss=per_example_loss,
      client_optimizer=client_optimizer,
      server_optimizer=server_optimizer,
      maximization_batch_hparams=maximization_batch_hparams,
      expectation_batch_hparams=expectation_batch_hparams,
      regularizer=regularizer)
  init_state = algorithm.init([jnp.array(1.), jnp.array(-1.)])
  # Nothing happens with empty data.
  no_op_state, diagnostics = algorithm.apply(init_state, clients=[])
  npt.assert_array_equal(init_state.cluster_params,
                         no_op_state.cluster_params)
  self.assertEmpty(diagnostics)
  # Some actual training. PRNGKeys are not actually used.
  clients = [
      (b'0', client_datasets.ClientDataset({'x': np.array([1.1])}),
       jax.random.PRNGKey(0)),
      (b'1', client_datasets.ClientDataset({'x': np.array([0.9, 0.9])}),
       jax.random.PRNGKey(1)),
      (b'2', client_datasets.ClientDataset({'x': np.array([-1.1])}),
       jax.random.PRNGKey(2)),
      (b'3',
       client_datasets.ClientDataset({'x': np.array([-0.9, -0.9, -0.9])}),
       jax.random.PRNGKey(3)),
  ]
  next_state, diagnostics = algorithm.apply(init_state, clients)
  npt.assert_equal(
      diagnostics, {
          b'0': {
              'cluster_id': 0
          },
          b'1': {
              'cluster_id': 0
          },
          b'2': {
              'cluster_id': 1
          },
          b'3': {
              'cluster_id': 1
          },
      })
  cluster_params = next_state.cluster_params
  self.assertIsInstance(cluster_params, list)
  self.assertLen(cluster_params, 2)
  npt.assert_allclose(cluster_params[0], [1. - 0.25 * 0.1 / 3])
  npt.assert_allclose(cluster_params[1], [-1. + 0.25 * 0.2 / 4])
  self.assertCountEqual(functions_called,
                        ['per_example_loss', 'regularizer'])
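# Why these cluster updates: the maximization step assigns each client to
# the nearest cluster (positive x's to cluster 0 at 1., negative x's to
# cluster 1 at -1.). In the expectation step, SGD with learning rate 0.5 on
# the squared loss moves the params to each client's x in one step, so the
# example-weighted average delta for cluster 0 is
# ((1. - 1.1) * 1 + (1. - 0.9) * 2) / 3 = 0.1 / 3, which the server applies
# with learning rate 0.25: 1. - 0.25 * 0.1 / 3. Cluster 1 is analogous with
# clients b'2' and b'3'.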
class ClientTrainerTest(absltest.TestCase):
  OPTIMIZER = optimizers.sgd(1)

  def test_train_global_params(self):

    def grad(params, batch, rng):
      return 0.5 * params + jnp.mean(batch['x']) + jax.random.uniform(rng, [])

    rng_0 = jax.random.PRNGKey(0)
    rng_uniform_00 = jax.random.uniform(jax.random.split(rng_0)[1], [])
    rng_uniform_01 = jax.random.uniform(
        jax.random.split(jax.random.split(rng_0)[0])[1], [])
    rng_1 = jax.random.PRNGKey(1)
    rng_uniform_10 = jax.random.uniform(jax.random.split(rng_1)[1], [])
    params = jnp.array(1.)
    clients = [(b'0000', [{
        'x': jnp.array([0.125])
    }, {
        'x': jnp.array([0.25, 0.75])
    }], rng_0), (b'1001', [{
        'x': jnp.array([0, 1, 2])
    }], rng_1)]
    client_delta_params = dict(
        hyp_cluster.ClientDeltaTrainer(
            grad, self.OPTIMIZER).train_global_params(params, clients))
    self.assertCountEqual(client_delta_params, [b'0000', b'1001'])
    npt.assert_allclose(
        client_delta_params[b'0000'],
        (0.5 * 1 + 0.125 + rng_uniform_00) +
        (0.5 * (0.375 - rng_uniform_00) + 0.5 + rng_uniform_01))
    npt.assert_allclose(client_delta_params[b'1001'],
                        (0.5 * 1 + 1) + rng_uniform_10)

  def test_train_per_client_params(self):

    def grad(params, batch, rng):
      return 0.5 * params + jnp.mean(batch['x']) + jax.random.uniform(rng, [])

    rng_0 = jax.random.PRNGKey(0)
    rng_uniform_00 = jax.random.uniform(jax.random.split(rng_0)[1], [])
    rng_uniform_01 = jax.random.uniform(
        jax.random.split(jax.random.split(rng_0)[0])[1], [])
    rng_1 = jax.random.PRNGKey(1)
    rng_uniform_10 = jax.random.uniform(jax.random.split(rng_1)[1], [])
    clients = [
        (b'0000', [{
            'x': jnp.array([0.125])
        }, {
            'x': jnp.array([0.25, 0.75])
        }], rng_0, jnp.array(1.)),
        (b'1001', [{
            'x': jnp.array([0, 1, 2])
        }], rng_1, jnp.array(0.5)),
    ]
    client_delta_params = dict(
        hyp_cluster.ClientDeltaTrainer(
            grad, self.OPTIMIZER).train_per_client_params(clients))
    self.assertCountEqual(client_delta_params, [b'0000', b'1001'])
    npt.assert_allclose(
        client_delta_params[b'0000'],
        (0.5 * 1 + 0.125 + rng_uniform_00) +
        (0.5 * (0.375 - rng_uniform_00) + 0.5 + rng_uniform_01))
    npt.assert_allclose(client_delta_params[b'1001'],
                        (0.5 * 0.5 + 1) + rng_uniform_10)

  def test_return_params(self):

    def grad(params, batch, _):
      return 0.5 * params + jnp.mean(batch['x'])

    params = jnp.array(1.)
    clients = [(b'0000', [{
        'x': jnp.array([0.125])
    }, {
        'x': jnp.array([0.25, 0.75])
    }], jax.random.PRNGKey(0)), (b'1001', [{
        'x': jnp.array([0, 1, 2])
    }], jax.random.PRNGKey(1))]
    client_delta_params = dict(
        hyp_cluster.ClientParamsTrainer(
            grad, self.OPTIMIZER).train_global_params(params, clients))
    npt.assert_equal(
        jax.device_get(client_delta_params), {
            b'0000': 1 - ((0.5 * 1 + 0.125) + (0.5 * 0.375 + 0.5)),
            b'1001': 1 - (0.5 * 1 + 1)
        })
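# Note on the rng bookkeeping in the two stochastic tests above: the expected
# values encode a per-step key schedule in which each local step splits the
# carried key and consumes one half, i.e. roughly
# carry, use = jax.random.split(rng); step 1 draws uniform(split(rng_0)[1])
# and step 2 draws uniform(split(split(rng_0)[0])[1]). The grad function
# itself is an arbitrary affine-in-params test stub, not a real model
# gradient; ClientDeltaTrainer returns initial minus final params, while
# ClientParamsTrainer returns the final params themselves.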
def test_kmeans_init(self):
  functions_called = set()

  def init(rng):
    functions_called.add('init')
    return jax.random.uniform(rng)

  def apply_for_train(params, batch, rng):
    functions_called.add('apply_for_train')
    self.assertIsNotNone(rng)
    return params - batch['x']

  def train_loss(batch, out):
    functions_called.add('train_loss')
    return jnp.square(out) + batch['bias']

  def regularizer(params):
    del params
    functions_called.add('regularizer')
    return 0

  initializer = hyp_cluster.ModelKMeansInitializer(
      models.Model(
          init=init,
          apply_for_train=apply_for_train,
          apply_for_eval=None,
          train_loss=train_loss,
          eval_metrics={}), optimizers.sgd(0.5), regularizer)
  # Each client has 1 example, so it's very easy to reach minimal loss, at
  # which point the loss entirely depends on bias.
  clients = [
      (b'0',
       client_datasets.ClientDataset({
           'x': np.array([1.01]),
           'bias': np.array([-2.])
       }), jax.random.PRNGKey(1)),
      (b'1',
       client_datasets.ClientDataset({
           'x': np.array([3.02]),
           'bias': np.array([-1.])
       }), jax.random.PRNGKey(2)),
      (b'2',
       client_datasets.ClientDataset({
           'x': np.array([3.03]),
           'bias': np.array([1.])
       }), jax.random.PRNGKey(3)),
      (b'3',
       client_datasets.ClientDataset({
           'x': np.array([1.04]),
           'bias': np.array([2.])
       }), jax.random.PRNGKey(3)),
  ]
  train_batch_hparams = client_datasets.ShuffleRepeatBatchHParams(
      batch_size=1, num_epochs=5)
  eval_batch_hparams = client_datasets.PaddedBatchHParams(batch_size=2)
  # Using a rng that leads to b'0' being the initial center.
  cluster_params = initializer.cluster_params(
      num_clusters=3,
      rng=jax.random.PRNGKey(0),
      clients=clients,
      train_batch_hparams=train_batch_hparams,
      eval_batch_hparams=eval_batch_hparams)
  self.assertIsInstance(cluster_params, list)
  self.assertLen(cluster_params, 3)
  npt.assert_allclose(cluster_params, [1.01, 3.03, 1.04])
  self.assertCountEqual(
      functions_called,
      ['init', 'apply_for_train', 'train_loss', 'regularizer'])
  # Using a rng that leads to b'2' being the initial center.
  cluster_params = initializer.cluster_params(
      num_clusters=3,
      rng=jax.random.PRNGKey(1),
      clients=clients,
      train_batch_hparams=train_batch_hparams,
      eval_batch_hparams=eval_batch_hparams)
  self.assertIsInstance(cluster_params, list)
  self.assertLen(cluster_params, 3)
  npt.assert_allclose(cluster_params, [3.03, 1.04, 1.04])
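# Reading guide: this appears to follow a k-means++-style initialization.
# The first center is a randomly chosen client's locally trained params
# (here each client trains to params == its x, since apply_for_train returns
# params - x and the squared train_loss is minimized at 0). Subsequent
# centers are picked by evaluating candidate params on the clients, with the
# per-example 'bias' shifting the reported loss, which is why the selected
# centers depend on the seed and why repeats like the two 1.04 entries can
# appear.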