def test_expectation_step(self):

  def per_example_loss(params, batch, rng):
    self.assertIsNotNone(rng)
    return jnp.square(params - batch['x'])

  trainer = hyp_cluster.ClientDeltaTrainer(
      models.grad(per_example_loss), optimizers.sgd(0.5))
  batch_hparams = client_datasets.ShuffleRepeatBatchHParams(
      batch_size=1, num_epochs=5)

  cluster_params = [jnp.array(1.), jnp.array(-1.), jnp.array(3.14)]
  client_cluster_ids = {b'0': 0, b'1': 0, b'2': 1, b'3': 1, b'4': 0}
  # RNGs are not actually used.
  clients = [
      (b'0', client_datasets.ClientDataset({'x': np.array([1.1])}),
       jax.random.PRNGKey(0)),
      (b'1', client_datasets.ClientDataset({'x': np.array([0.9, 0.9])}),
       jax.random.PRNGKey(1)),
      (b'2', client_datasets.ClientDataset({'x': np.array([-1.1])}),
       jax.random.PRNGKey(2)),
      (b'3',
       client_datasets.ClientDataset({'x': np.array([-0.9, -0.9, -0.9])}),
       jax.random.PRNGKey(3)),
      (b'4', client_datasets.ClientDataset({'x': np.array([-0.1])}),
       jax.random.PRNGKey(4)),
  ]

  cluster_delta_params = hyp_cluster.expectation_step(
      trainer=trainer,
      cluster_params=cluster_params,
      client_cluster_ids=client_cluster_ids,
      clients=clients,
      batch_hparams=batch_hparams)

  self.assertIsInstance(cluster_delta_params, list)
  self.assertLen(cluster_delta_params, 3)
  npt.assert_allclose(cluster_delta_params[0], (-0.1 + 0.1 * 2 + 1.1) / 4)
  npt.assert_allclose(cluster_delta_params[1], (0.1 - 0.1 * 3) / 4, rtol=1e-6)
  self.assertIsNone(cluster_delta_params[2])
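# A minimal arithmetic sketch (not part of the original test) of where the
# expected value for cluster 0 above comes from. With per-example loss
# (params - x)**2 the gradient is 2 * (params - x), so a single SGD step with
# learning rate 0.5 moves the parameter exactly onto the example value x, and
# every later step sees a zero gradient. Each client's delta is therefore
# init_param - x, and expectation_step averages the deltas within a cluster,
# weighted by each client's number of examples.
import numpy.testing as npt

init_param = 1.
# Cluster 0 clients b'0', b'1', b'4' as (delta, num_examples) under the
# single-step argument above.
cluster0 = [(init_param - 1.1, 1), (init_param - 0.9, 2),
            (init_param - (-0.1), 1)]
weighted_sum = sum(delta * n for delta, n in cluster0)
num_examples = sum(n for _, n in cluster0)
npt.assert_allclose(weighted_sum / num_examples, (-0.1 + 0.1 * 2 + 1.1) / 4)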
def mime(
    per_example_loss: Callable[[Params, BatchExample, PRNGKey], jnp.ndarray],
    base_optimizer: optimizers.Optimizer,
    client_batch_hparams: client_datasets.ShuffleRepeatBatchHParams,
    grads_batch_hparams: client_datasets.PaddedBatchHParams,
    server_learning_rate: float,
    regularizer: Optional[Callable[[Params], jnp.ndarray]] = None
) -> federated_algorithm.FederatedAlgorithm:
  """Builds the Mime federated learning algorithm.

  Args:
    per_example_loss: A function from (params, batch_example, rng) to a vector
      of loss values for each example in the batch. This is used in both the
      server gradient computation and gradient descent training.
    base_optimizer: Base optimizer to mimic.
    client_batch_hparams: Hyperparameters for batching client datasets for
      training.
    grads_batch_hparams: Hyperparameters for batching client datasets for the
      server gradient computation.
    server_learning_rate: Server learning rate.
    regularizer: Optional regularizer that only depends on params.

  Returns:
    A FederatedAlgorithm.
  """
  grad_fn = models.grad(per_example_loss, regularizer)
  grads_for_each_client = create_grads_for_each_client(grad_fn)
  train_for_each_client = create_train_for_each_client(grad_fn, base_optimizer)

  def init(params: Params) -> ServerState:
    opt_state = base_optimizer.init(params)
    return ServerState(params, opt_state)

  def apply(
      server_state: ServerState,
      clients: Sequence[Tuple[federated_data.ClientId,
                              client_datasets.ClientDataset, PRNGKey]]
  ) -> Tuple[ServerState, Mapping[federated_data.ClientId, Any]]:
    # Compute full-batch gradient at server params on train data.
    grads_batch_clients = [(cid, cds.padded_batch(grads_batch_hparams), crng)
                           for cid, cds, crng in clients]
    grads_sum_total, num_sum_total = tree_util.tree_sum(
        (co for _, co in grads_for_each_client(server_state.params,
                                               grads_batch_clients)))
    server_grads = tree_util.tree_inverse_weight(grads_sum_total,
                                                 num_sum_total)
    # Control variate corrected training across clients.
    client_diagnostics = {}
    client_num_examples = {cid: len(cds) for cid, cds, _ in clients}
    batch_clients = [(cid, cds.shuffle_repeat_batch(client_batch_hparams),
                      crng) for cid, cds, crng in clients]
    shared_input = {
        'params': server_state.params,
        'opt_state': server_state.opt_state,
        'control_variate': server_grads
    }
    # Running weighted mean of client updates.
    delta_params_sum = tree_util.tree_zeros_like(server_state.params)
    num_examples_sum = 0.
    for client_id, delta_params in train_for_each_client(
        shared_input, batch_clients):
      num_examples = client_num_examples[client_id]
      delta_params_sum = tree_util.tree_add(
          delta_params_sum, tree_util.tree_weight(delta_params, num_examples))
      num_examples_sum += num_examples
      client_diagnostics[client_id] = {
          'delta_l2_norm': tree_util.tree_l2_norm(delta_params)
      }
    mean_delta_params = tree_util.tree_inverse_weight(delta_params_sum,
                                                      num_examples_sum)
    server_state = server_update(server_state, server_grads, mean_delta_params)
    return server_state, client_diagnostics

  def server_update(server_state, server_grads, mean_delta_params):
    # Server params are updated with the weighted average of client updates,
    # scaled by server_learning_rate.
    params = jax.tree_util.tree_map(
        lambda p, q: p - server_learning_rate * q, server_state.params,
        mean_delta_params)
    opt_state, _ = base_optimizer.apply(server_grads, server_state.opt_state,
                                        server_state.params)
    return ServerState(params, opt_state)

  return federated_algorithm.FederatedAlgorithm(init, apply)
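# A minimal usage sketch of the function above (not part of mime.py). The toy
# per-example loss, hyperparameters, and client data below are illustrative
# assumptions, not values taken from the library.
import jax
import jax.numpy as jnp
import numpy as np
from fedjax.core import client_datasets
from fedjax.core import optimizers


def toy_per_example_loss(params, batch, rng):
  del rng  # Unused in this sketch.
  return batch['x'] * params['w']


toy_algorithm = mime(
    per_example_loss=toy_per_example_loss,
    base_optimizer=optimizers.sgd(learning_rate=1.0),
    client_batch_hparams=client_datasets.ShuffleRepeatBatchHParams(
        batch_size=2, num_epochs=1, seed=0),
    grads_batch_hparams=client_datasets.PaddedBatchHParams(batch_size=2),
    server_learning_rate=0.2)
state = toy_algorithm.init({'w': jnp.array(4.)})
toy_clients = [
    (b'c0', client_datasets.ClientDataset({'x': np.array([0.2, 0.4, 0.6])}),
     jax.random.PRNGKey(0)),
    (b'c1', client_datasets.ClientDataset({'x': np.array([0.8, 0.1])}),
     jax.random.PRNGKey(1)),
]
# One round: full-batch server gradients, control-variate-corrected client
# training, then the server update. diagnostics[client_id]['delta_l2_norm']
# reports the L2 norm of each client's update.
state, diagnostics = toy_algorithm.apply(state, toy_clients)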
from absl.testing import absltest
from fedjax.algorithms import mime
from fedjax.core import client_datasets
from fedjax.core import models
from fedjax.core import optimizers
import jax
import jax.numpy as jnp
import numpy.testing as npt


def per_example_loss(params, batch, rng):
  del rng
  return batch['x'] * params['w']


grad_fn = models.grad(per_example_loss)


class MimeTest(absltest.TestCase):

  def test_mime(self):
    base_optimizer = optimizers.sgd(learning_rate=1.0)
    train_batch_hparams = client_datasets.ShuffleRepeatBatchHParams(
        batch_size=2, num_epochs=1, seed=0)
    grad_batch_hparams = client_datasets.PaddedBatchHParams(batch_size=2)
    server_learning_rate = 0.2
    algorithm = mime.mime(per_example_loss, base_optimizer,
                          train_batch_hparams, grad_batch_hparams,
                          server_learning_rate)

    with self.subTest('init'):
      state = algorithm.init({'w': jnp.array(4.)})
def hyp_cluster(
    per_example_loss: Callable[[Params, BatchExample, PRNGKey], jnp.ndarray],
    client_optimizer: optimizers.Optimizer,
    server_optimizer: optimizers.Optimizer,
    maximization_batch_hparams: client_datasets.PaddedBatchHParams,
    expectation_batch_hparams: client_datasets.ShuffleRepeatBatchHParams,
    regularizer: Optional[Callable[[Params], jnp.ndarray]] = None
) -> federated_algorithm.FederatedAlgorithm:
  """Federated hypothesis-based clustering algorithm.

  Args:
    per_example_loss: A function from (params, batch, rng) to a vector of per
      example loss values.
    client_optimizer: Client side optimizer.
    server_optimizer: Server side optimizer.
    maximization_batch_hparams: Batching hyperparameters for the maximization
      step.
    expectation_batch_hparams: Batching hyperparameters for the expectation
      step.
    regularizer: Optional regularizer.

  Returns:
    A FederatedAlgorithm. A notable difference from other common
    FederatedAlgorithms such as fed_avg is that `init()` takes a list of
    `cluster_params`, which can be obtained using `random_init()`,
    `ModelKMeansInitializer`, or `kmeans_init()` in this module.
  """

  def init(cluster_params: List[Params]) -> ServerState:
    return ServerState(
        cluster_params,
        [server_optimizer.init(params) for params in cluster_params])

  # Creating these objects outside apply() can speed up repeated apply() calls.
  evaluator = models.AverageLossEvaluator(per_example_loss, regularizer)
  trainer = ClientDeltaTrainer(
      models.grad(per_example_loss, regularizer), client_optimizer)

  def apply(
      server_state: ServerState,
      clients: Sequence[Tuple[federated_data.ClientId,
                              client_datasets.ClientDataset, PRNGKey]]
  ) -> Tuple[ServerState, Mapping[federated_data.ClientId, Any]]:
    # Split RNGs for the 2 steps below.
    client_rngs = [jax.random.split(rng) for _, _, rng in clients]

    client_cluster_ids = maximization_step(
        evaluator=evaluator,
        cluster_params=server_state.cluster_params,
        clients=[(client_id, dataset, rng[0])
                 for (client_id, dataset, _), rng in zip(clients, client_rngs)],
        batch_hparams=maximization_batch_hparams)

    cluster_delta_params = expectation_step(
        trainer=trainer,
        cluster_params=server_state.cluster_params,
        client_cluster_ids=client_cluster_ids,
        clients=[(client_id, dataset, rng[1])
                 for (client_id, dataset, _), rng in zip(clients, client_rngs)],
        batch_hparams=expectation_batch_hparams)

    # Apply delta params for each cluster.
    cluster_params = []
    opt_states = []
    for delta_params, opt_state, params in zip(cluster_delta_params,
                                                server_state.opt_states,
                                                server_state.cluster_params):
      if delta_params is None:
        # No examples were observed for this cluster.
        next_opt_state, next_params = opt_state, params
      else:
        next_opt_state, next_params = server_optimizer.apply(
            delta_params, opt_state, params)
      cluster_params.append(next_params)
      opt_states.append(next_opt_state)

    # TODO(wuke): Other client diagnostics.
    client_diagnostics = {}
    for client_id, cluster_id in client_cluster_ids.items():
      client_diagnostics[client_id] = {'cluster_id': cluster_id}

    return ServerState(cluster_params, opt_states), client_diagnostics

  return federated_algorithm.FederatedAlgorithm(init, apply)
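# A minimal usage sketch of the function above (not part of hyp_cluster.py).
# It assumes a scalar toy model so cluster params can be plain arrays; the
# optimizer settings and client data are illustrative. In practice the initial
# cluster params would come from random_init(), ModelKMeansInitializer, or
# kmeans_init() as noted in the docstring above.
import jax
import jax.numpy as jnp
import numpy as np
from fedjax.core import client_datasets
from fedjax.core import optimizers


def toy_per_example_loss(params, batch, rng):
  del rng  # Unused in this sketch.
  return jnp.square(params - batch['x'])


toy_algorithm = hyp_cluster(
    per_example_loss=toy_per_example_loss,
    client_optimizer=optimizers.sgd(0.5),
    server_optimizer=optimizers.sgd(1.0),
    maximization_batch_hparams=client_datasets.PaddedBatchHParams(
        batch_size=2),
    expectation_batch_hparams=client_datasets.ShuffleRepeatBatchHParams(
        batch_size=1, num_epochs=5))
# Two hypothesized clusters, initialized by hand for this sketch.
state = toy_algorithm.init([jnp.array(1.), jnp.array(-1.)])
toy_clients = [
    (b'a', client_datasets.ClientDataset({'x': np.array([1.1])}),
     jax.random.PRNGKey(0)),
    (b'b', client_datasets.ClientDataset({'x': np.array([-0.9, -1.1])}),
     jax.random.PRNGKey(1)),
]
# One round: each client is assigned to its best cluster (maximization step),
# then each cluster's params are trained on its assigned clients (expectation
# step). diagnostics[client_id]['cluster_id'] reports the assignment.
state, diagnostics = toy_algorithm.apply(state, toy_clients)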