def fake_algorithm():
  """Counts the number of clients and sums up the 'x' feature."""

  def init():
    # State is a (num_clients, sum_values) pair.
    return 0, 0

  def apply(state, clients):
    client_count, feature_total = state
    for _, dataset, _ in clients:
      client_count += 1
      # Accumulate every value of the 'x' feature across the full dataset.
      feature_total += sum(dataset.all_examples()['x'])
    return (client_count, feature_total), None

  return federated_algorithm.FederatedAlgorithm(init, apply)
def mime(
    per_example_loss: Callable[[Params, BatchExample, PRNGKey], jnp.ndarray],
    base_optimizer: optimizers.Optimizer,
    client_batch_hparams: client_datasets.ShuffleRepeatBatchHParams,
    grads_batch_hparams: client_datasets.PaddedBatchHParams,
    server_learning_rate: float,
    regularizer: Optional[Callable[[Params], jnp.ndarray]] = None
) -> federated_algorithm.FederatedAlgorithm:
  """Builds mime.

  Args:
    per_example_loss: A function from (params, batch_example, rng) to a vector
      of loss values for each example in the batch. This is used in both the
      server gradient computation and gradient descent training.
    base_optimizer: Base optimizer to mimic.
    client_batch_hparams: Hyperparameters for batching client dataset for
      train.
    grads_batch_hparams: Hyperparameters for batching client dataset for
      server gradient computation.
    server_learning_rate: Server learning rate.
    regularizer: Optional regularizer that only depends on params.

  Returns:
    FederatedAlgorithm
  """
  grad_fn = models.grad(per_example_loss, regularizer)
  grads_for_each_client = create_grads_for_each_client(grad_fn)
  train_for_each_client = create_train_for_each_client(grad_fn, base_optimizer)

  def init(params: Params) -> ServerState:
    opt_state = base_optimizer.init(params)
    return ServerState(params, opt_state)

  def apply(
      server_state: ServerState,
      clients: Sequence[Tuple[federated_data.ClientId,
                              client_datasets.ClientDataset, PRNGKey]]
  ) -> Tuple[ServerState, Mapping[federated_data.ClientId, Any]]:
    # Compute full-batch gradient at server params on train data. Each client
    # contributes a (gradient sum, example count) pair; summing these across
    # clients and dividing by the total count gives the average gradient that
    # is used as the control variate below.
    grads_batch_clients = [(cid, cds.padded_batch(grads_batch_hparams), crng)
                           for cid, cds, crng in clients]
    grads_sum_total, num_sum_total = tree_util.tree_sum(
        (co for _, co in grads_for_each_client(server_state.params,
                                               grads_batch_clients)))
    server_grads = tree_util.tree_inverse_weight(grads_sum_total,
                                                 num_sum_total)
    # Control variate corrected training across clients.
    client_diagnostics = {}
    client_num_examples = {cid: len(cds) for cid, cds, _ in clients}
    batch_clients = [(cid, cds.shuffle_repeat_batch(client_batch_hparams),
                      crng) for cid, cds, crng in clients]
    # Every client trains from the same params/opt_state and uses the same
    # server gradient as the control variate.
    shared_input = {
        'params': server_state.params,
        'opt_state': server_state.opt_state,
        'control_variate': server_grads
    }
    # Running weighted mean of client updates, accumulated iteratively so we
    # never hold all client deltas in memory at once.
    delta_params_sum = tree_util.tree_zeros_like(server_state.params)
    num_examples_sum = 0.
    for client_id, delta_params in train_for_each_client(
        shared_input, batch_clients):
      num_examples = client_num_examples[client_id]
      delta_params_sum = tree_util.tree_add(
          delta_params_sum, tree_util.tree_weight(delta_params, num_examples))
      num_examples_sum += num_examples
      # The l2 norm of client updates is recorded as a diagnostic only; the
      # algorithm itself does not use it.
      client_diagnostics[client_id] = {
          'delta_l2_norm': tree_util.tree_l2_norm(delta_params)
      }
    mean_delta_params = tree_util.tree_inverse_weight(delta_params_sum,
                                                      num_examples_sum)
    server_state = server_update(server_state, server_grads,
                                 mean_delta_params)
    return server_state, client_diagnostics

  def server_update(server_state, server_grads, mean_delta_params):
    # Server params uses weighted average of client updates, scaled by the
    # server_learning_rate.
    params = jax.tree_util.tree_map(
        lambda p, q: p - server_learning_rate * q, server_state.params,
        mean_delta_params)
    # The base optimizer state is advanced with the full-batch server gradient
    # (not the client deltas); its params output is intentionally discarded
    # because params were already updated above.
    opt_state, _ = base_optimizer.apply(server_grads, server_state.opt_state,
                                        server_state.params)
    return ServerState(params, opt_state)

  return federated_algorithm.FederatedAlgorithm(init, apply)
def federated_averaging(
    grad_fn: Callable[[Params, BatchExample, PRNGKey], Grads],
    client_optimizer: optimizers.Optimizer,
    server_optimizer: optimizers.Optimizer,
    client_batch_hparams: client_datasets.ShuffleRepeatBatchHParams
) -> federated_algorithm.FederatedAlgorithm:
  """Builds federated averaging.

  Args:
    grad_fn: A function from (params, batch_example, rng) to gradients. This
      can be created with :func:`fedjax.core.model.model_grad`.
    client_optimizer: Optimizer for local client training.
    server_optimizer: Optimizer for server update.
    client_batch_hparams: Hyperparameters for batching client dataset for
      train.

  Returns:
    FederatedAlgorithm
  """
  train_for_each_client = create_train_for_each_client(grad_fn,
                                                       client_optimizer)

  def init(params: Params) -> ServerState:
    return ServerState(params, server_optimizer.init(params))

  def server_update(server_state, mean_delta_params):
    # Push the aggregated client delta through the server optimizer.
    opt_state, params = server_optimizer.apply(mean_delta_params,
                                               server_state.opt_state,
                                               server_state.params)
    return ServerState(params, opt_state)

  def apply(
      server_state: ServerState,
      clients: Sequence[Tuple[federated_data.ClientId,
                              client_datasets.ClientDataset, PRNGKey]]
  ) -> Tuple[ServerState, Mapping[federated_data.ClientId, Any]]:
    dataset_sizes = {cid: len(cds) for cid, cds, _ in clients}
    batched_clients = [(cid, cds.shuffle_repeat_batch(client_batch_hparams),
                        crng) for cid, cds, crng in clients]
    diagnostics = {}
    # Maintain a running weighted sum of client deltas instead of collecting
    # all client outputs first: the outputs can be prohibitively large
    # depending on the model parameters size.
    weighted_delta_sum = tree_util.tree_zeros_like(server_state.params)
    total_weight = 0.
    for cid, delta in train_for_each_client(server_state.params,
                                            batched_clients):
      weight = dataset_sizes[cid]
      weighted_delta_sum = tree_util.tree_add(
          weighted_delta_sum, tree_util.tree_weight(delta, weight))
      total_weight += weight
      # The l2 norm of the client update is recorded purely as an example
      # diagnostic; the algorithm does not require it.
      diagnostics[cid] = {'delta_l2_norm': tree_util.tree_l2_norm(delta)}
    mean_delta = tree_util.tree_inverse_weight(weighted_delta_sum,
                                               total_weight)
    return server_update(server_state, mean_delta), diagnostics

  return federated_algorithm.FederatedAlgorithm(init, apply)
def hyp_cluster(
    per_example_loss: Callable[[Params, BatchExample, PRNGKey], jnp.ndarray],
    client_optimizer: optimizers.Optimizer,
    server_optimizer: optimizers.Optimizer,
    maximization_batch_hparams: client_datasets.PaddedBatchHParams,
    expectation_batch_hparams: client_datasets.ShuffleRepeatBatchHParams,
    regularizer: Optional[Callable[[Params], jnp.ndarray]] = None
) -> federated_algorithm.FederatedAlgorithm:
  """Federated hypothesis-based clustering algorithm.

  Args:
    per_example_loss: A function from (params, batch, rng) to a vector of per
      example loss values.
    client_optimizer: Client side optimizer.
    server_optimizer: Server side optimizer.
    maximization_batch_hparams: Batching hyperparameters for the maximization
      step.
    expectation_batch_hparams: Batching hyperparameters for the expectation
      step.
    regularizer: Optional regularizer.

  Returns:
    A FederatedAlgorithm. A notable difference from other common
    FederatedAlgorithms such as fed_avg is that `init()` takes a list of
    `cluster_params`, which can be obtained using `random_init()`,
    `ModelKMeansInitializer`, or `kmeans_init()` in this module.
  """

  def init(cluster_params: List[Params]) -> ServerState:
    # One optimizer state per cluster, paired index-wise with cluster_params.
    return ServerState(
        cluster_params,
        [server_optimizer.init(params) for params in cluster_params])

  # Creating these objects outside apply() can speed up repeated apply() calls.
  evaluator = models.AverageLossEvaluator(per_example_loss, regularizer)
  trainer = ClientDeltaTrainer(
      models.grad(per_example_loss, regularizer), client_optimizer)

  def apply(
      server_state: ServerState,
      clients: Sequence[Tuple[federated_data.ClientId,
                              client_datasets.ClientDataset, PRNGKey]]
  ) -> Tuple[ServerState, Mapping[federated_data.ClientId, Any]]:
    # Split RNGs for the 2 steps below: rng[0] goes to the maximization step,
    # rng[1] to the expectation step, so the two steps never share randomness.
    client_rngs = [jax.random.split(rng) for _, _, rng in clients]
    # Maximization step: produces a mapping from client id to the cluster id
    # assigned to that client (see its use in the diagnostics loop below).
    client_cluster_ids = maximization_step(
        evaluator=evaluator,
        cluster_params=server_state.cluster_params,
        clients=[(client_id, dataset, rng[0])
                 for (client_id, dataset, _), rng in zip(clients, client_rngs)],
        batch_hparams=maximization_batch_hparams)
    # Expectation step: produces one delta-params entry per cluster (aligned
    # with server_state.cluster_params), or None for clusters with no data.
    cluster_delta_params = expectation_step(
        trainer=trainer,
        cluster_params=server_state.cluster_params,
        client_cluster_ids=client_cluster_ids,
        clients=[(client_id, dataset, rng[1])
                 for (client_id, dataset, _), rng in zip(clients, client_rngs)],
        batch_hparams=expectation_batch_hparams)
    # Apply delta params for each cluster.
    cluster_params = []
    opt_states = []
    for delta_params, opt_state, params in zip(cluster_delta_params,
                                               server_state.opt_states,
                                               server_state.cluster_params):
      if delta_params is None:
        # No examples were observed for this cluster; carry state unchanged.
        next_opt_state, next_params = opt_state, params
      else:
        next_opt_state, next_params = server_optimizer.apply(
            delta_params, opt_state, params)
      cluster_params.append(next_params)
      opt_states.append(next_opt_state)
    # TODO(wuke): Other client diagnostics.
    client_diagnostics = {}
    for client_id, cluster_id in client_cluster_ids.items():
      client_diagnostics[client_id] = {'cluster_id': cluster_id}
    return ServerState(cluster_params, opt_states), client_diagnostics

  return federated_algorithm.FederatedAlgorithm(init, apply)
def agnostic_federated_averaging(
    per_example_loss: Callable[[Params, BatchExample, PRNGKey], jnp.ndarray],
    client_optimizer: optimizers.Optimizer,
    server_optimizer: optimizers.Optimizer,
    client_batch_hparams: client_datasets.ShuffleRepeatBatchHParams,
    domain_batch_hparams: client_datasets.PaddedBatchHParams,
    init_domain_weights: Sequence[float],
    domain_learning_rate: float,
    domain_algorithm: str = 'eg',
    domain_window_size: int = 1,
    init_domain_window: Optional[Sequence[float]] = None,
    regularizer: Optional[Callable[[Params], jnp.ndarray]] = None
) -> federated_algorithm.FederatedAlgorithm:
  """Builds agnostic federated averaging.

  Agnostic federated averaging requires input
  :class:`fedjax.core.client_datasets.ClientDataset` examples to contain a
  feature named "domain_id", which stores the integer domain id in
  [0, num_domains). For example, for Stack Overflow, each example post can be
  either a question or an answer, so there are two possible domain ids
  (question = 0; answer = 1).

  Args:
    per_example_loss: A function from (params, batch_example, rng) to a vector
      of loss values for each example in the batch. This is used in both the
      domain metrics computation and gradient descent training.
    client_optimizer: Optimizer for local client training.
    server_optimizer: Optimizer for server update.
    client_batch_hparams: Hyperparameters for client dataset for training.
    domain_batch_hparams: Hyperparameters for client dataset domain metrics
      calculation.
    init_domain_weights: Initial weights per domain that must sum to 1.
    domain_learning_rate: Learning rate for domain weight update.
    domain_algorithm: Algorithm used to update domain weights each round. One
      of 'eg', 'none'.
    domain_window_size: Size of sliding window keeping track of number of
      examples per domain over multiple rounds.
    init_domain_window: Initial values for domain window. Defaults to ones.
    regularizer: Optional regularizer that only depends on params.

  Returns:
    FederatedAlgorithm.

  Raises:
    ValueError: If ``init_domain_weights`` does not sum to 1 or if
      ``init_domain_weights`` and ``init_domain_window`` are unequal lengths.
  """
  # Validate hyperparameters eagerly so misconfiguration fails at build time
  # rather than mid-training.
  if abs(sum(init_domain_weights) - 1) > 1e-6:
    raise ValueError('init_domain_weights must sum to approximately 1.')
  if init_domain_window is None:
    init_domain_window = jnp.ones_like(init_domain_weights)
  if len(init_domain_weights) != len(init_domain_window):
    raise ValueError(
        f'init_domain_weights and init_domain_window must be equal lengths.'
        f' {len(init_domain_weights)} != {len(init_domain_window)}')
  num_domains = len(init_domain_weights)
  domain_metrics_for_each_client = create_domain_metrics_for_each_client(
      per_example_loss, num_domains)
  train_for_each_client = create_train_for_each_client(
      per_example_loss, client_optimizer, num_domains, regularizer)

  def init(params: Params) -> ServerState:
    opt_state = server_optimizer.init(params)
    domain_weights = jnp.array(init_domain_weights)
    # The window holds domain_window_size copies of the initial per-domain
    # counts; each round drops the oldest entry and appends the newest.
    domain_window = [jnp.array(init_domain_window)] * domain_window_size
    return ServerState(params, opt_state, domain_weights, domain_window)

  def apply(
      server_state: ServerState,
      clients: Sequence[Tuple[federated_data.ClientId,
                              client_datasets.ClientDataset, PRNGKey]]
  ) -> Tuple[ServerState, Mapping[federated_data.ClientId, Any]]:
    # α: current domain weights divided by the windowed mean of per-domain
    # example counts.
    alpha = server_state.domain_weights / jnp.mean(
        jnp.asarray(server_state.domain_window), axis=0)
    # First pass to calculate initial domain loss, domain num, and scaling
    # weight β for each client. This doesn't involve any aggregation at the
    # server, so this step and training can be a single round of
    # communication.
    domain_batch_clients = [(cid, cds.padded_batch(domain_batch_hparams), crng)
                            for cid, cds, crng in clients]
    shared_input = {'params': server_state.params, 'alpha': alpha}
    # L^k, N^k, β^k: per-client mapping with 'domain_loss', 'domain_num' and
    # 'beta' entries (keys used below).
    client_domain_metrics = dict(
        domain_metrics_for_each_client(shared_input, domain_batch_clients))
    # Train for each client using scaling weights α and β.
    batch_clients = []
    for cid, cds, crng in clients:
      client_input = {'rng': crng, 'beta': client_domain_metrics[cid]['beta']}
      batch_clients.append(
          (cid, cds.shuffle_repeat_batch(client_batch_hparams), client_input))
    client_diagnostics = {}
    # Mean delta params across clients, weighted by each client's β.
    delta_params_sum = tree_util.tree_zeros_like(server_state.params)
    weight_sum = 0.  # w^k
    for cid, delta_params in train_for_each_client(shared_input,
                                                   batch_clients):
      weight = client_domain_metrics[cid]['beta']
      delta_params_sum = tree_util.tree_add(
          delta_params_sum, tree_util.tree_weight(delta_params, weight))
      weight_sum += weight
      # Recorded as a diagnostic only; not used by the algorithm.
      client_diagnostics[cid] = {
          'delta_l2_norm': tree_util.tree_l2_norm(delta_params)
      }
    mean_delta_params = tree_util.tree_inverse_weight(delta_params_sum,
                                                      weight_sum)
    # Sum domain metrics across clients.
    sum_domain_loss = tree_util.tree_sum(
        d['domain_loss'] for d in client_domain_metrics.values())
    sum_domain_num = tree_util.tree_sum(
        d['domain_num'] for d in client_domain_metrics.values())
    server_state = server_update(server_state, mean_delta_params,
                                 sum_domain_loss, sum_domain_num)
    return server_state, client_diagnostics

  def server_update(server_state, mean_delta_params, sum_domain_loss,
                    sum_domain_num):
    # Apply the aggregated client delta through the server optimizer.
    opt_state, params = server_optimizer.apply(mean_delta_params,
                                               server_state.opt_state,
                                               server_state.params)
    # Per-domain mean loss; safe_div guards domains with zero examples.
    mean_domain_loss = util.safe_div(sum_domain_loss, sum_domain_num)
    domain_weights = update_domain_weights(server_state.domain_weights,
                                           mean_domain_loss,
                                           domain_learning_rate,
                                           domain_algorithm)
    # Slide the window: drop the oldest per-domain counts, append this round's.
    domain_window = server_state.domain_window[1:] + [sum_domain_num]
    return ServerState(params, opt_state, domain_weights, domain_window)

  return federated_algorithm.FederatedAlgorithm(init, apply)
def fed_prox(per_example_loss: Callable[[Params, BatchExample, PRNGKey],
                                        jnp.ndarray],
             client_optimizer: optimizers.Optimizer,
             server_optimizer: optimizers.Optimizer,
             client_batch_hparams: client_datasets.ShuffleRepeatBatchHParams,
             proximal_weight: float) -> federated_algorithm.FederatedAlgorithm:
  """Builds FedProx.

  Args:
    per_example_loss: A function from (params, batch, rng) to a vector of per
      example loss values. This will be combined with a proximal term based on
      server params weighted by proximal_weight.
    client_optimizer: Optimizer for local client training.
    server_optimizer: Optimizer for server update.
    client_batch_hparams: Hyperparameters for batching client dataset for
      train.
    proximal_weight: Weight for proximal term. 0 weight is FedAvg.

  Returns:
    FederatedAlgorithm
  """

  def fed_prox_loss(params, server_params, batch, rng):
    # Mean per-example loss plus a quadratic penalty that keeps client params
    # close to the server params.
    example_loss = per_example_loss(params, batch, rng)
    param_gap = jax.tree_util.tree_map(lambda a, b: a - b, server_params,
                                       params)
    proximal_loss = 0.5 * proximal_weight * tree_util.tree_l2_squared(
        param_gap)
    return jnp.mean(example_loss + proximal_loss)

  grad_fn = jax.grad(fed_prox_loss)
  train_for_each_client = create_train_for_each_client(grad_fn,
                                                       client_optimizer)

  def init(params: Params) -> ServerState:
    return ServerState(params, server_optimizer.init(params))

  def server_update(server_state, mean_delta_params):
    # Push the aggregated client delta through the server optimizer.
    opt_state, params = server_optimizer.apply(mean_delta_params,
                                               server_state.opt_state,
                                               server_state.params)
    return ServerState(params, opt_state)

  def apply(
      server_state: ServerState,
      clients: Sequence[Tuple[federated_data.ClientId,
                              client_datasets.ClientDataset, PRNGKey]]
  ) -> Tuple[ServerState, Mapping[federated_data.ClientId, Any]]:
    dataset_sizes = {cid: len(cds) for cid, cds, _ in clients}
    batched_clients = [(cid, cds.shuffle_repeat_batch(client_batch_hparams),
                        crng) for cid, cds, crng in clients]
    diagnostics = {}
    # Maintain a running weighted sum of client deltas instead of collecting
    # all client outputs first: the outputs can be prohibitively large
    # depending on the model parameters size.
    weighted_delta_sum = tree_util.tree_zeros_like(server_state.params)
    total_weight = 0.
    for cid, delta in train_for_each_client(server_state.params,
                                            batched_clients):
      weight = dataset_sizes[cid]
      weighted_delta_sum = tree_util.tree_add(
          weighted_delta_sum, tree_util.tree_weight(delta, weight))
      total_weight += weight
      # The l2 norm of the client update is recorded purely as an example
      # diagnostic; the algorithm does not require it.
      diagnostics[cid] = {'delta_l2_norm': tree_util.tree_l2_norm(delta)}
    mean_delta = tree_util.tree_inverse_weight(weighted_delta_sum,
                                               total_weight)
    return server_update(server_state, mean_delta), diagnostics

  return federated_algorithm.FederatedAlgorithm(init, apply)