def apply(
    server_state: ServerState,
    clients: Sequence[Tuple[federated_data.ClientId,
                            client_datasets.ClientDataset, PRNGKey]]
) -> Tuple[ServerState, Mapping[federated_data.ClientId, Any]]:
  """Runs one round of federated training and averages client updates.

  Args:
    server_state: Current server state.
    clients: (client_id, dataset, rng) triples to train on this round.

  Returns:
    Updated server state and per-client diagnostics.
  """
  # Number of examples per client; used as each client's averaging weight.
  examples_per_client = {cid: len(dataset) for cid, dataset, _ in clients}
  batched_clients = [(cid, dataset.shuffle_repeat_batch(client_batch_hparams),
                      rng) for cid, dataset, rng in clients]
  diagnostics = {}
  # Accumulate a running weighted sum instead of holding every client output
  # in memory at once; updates can be prohibitively large for big models.
  weighted_delta_sum = tree_util.tree_zeros_like(server_state.params)
  total_examples = 0.
  for cid, delta in train_for_each_client(server_state.params,
                                          batched_clients):
    weight = examples_per_client[cid]
    weighted_delta_sum = tree_util.tree_add(
        weighted_delta_sum, tree_util.tree_weight(delta, weight))
    total_examples += weight
    # l2 norm of the client update is recorded as an example diagnostic; it is
    # not required by the algorithm itself.
    diagnostics[cid] = {'delta_l2_norm': tree_util.tree_l2_norm(delta)}
  mean_delta = tree_util.tree_inverse_weight(weighted_delta_sum,
                                             total_examples)
  return server_update(server_state, mean_delta), diagnostics
def test_tree_weight(self):
  """tree_weight multiplies every leaf of the pytree by the scalar weight."""
  tree = {
      'x': jnp.array([[[4, 5]], [[1, 1]]]),
      'y': jnp.array([[3], [1]]),
  }
  scaled = tree_util.tree_weight(tree, 2.0)
  self.assertAllEqual(scaled['x'], [[[8.0, 10.0]], [[2.0, 2.0]]])
  self.assertAllEqual(scaled['y'], [[6.0], [2.0]])
def apply(
    server_state: ServerState,
    clients: Sequence[Tuple[federated_data.ClientId,
                            client_datasets.ClientDataset, PRNGKey]]
) -> Tuple[ServerState, Mapping[federated_data.ClientId, Any]]:
  """Runs one round of domain-weighted federated training.

  Two passes over clients: a metrics pass that computes per-client domain
  statistics and a scaling weight beta, then a training pass whose deltas are
  averaged using beta as the weight.

  Args:
    server_state: Current server state (params, domain_weights, domain_window).
    clients: (client_id, dataset, rng) triples to train on this round.

  Returns:
    Updated server state and per-client diagnostics.
  """
  # α: domain scaling factors, domain weights normalized by the mean of the
  # recent domain window.
  alpha = server_state.domain_weights / jnp.mean(
      jnp.asarray(server_state.domain_window), axis=0)
  # First pass to calculate initial domain loss, domain num, and scaling
  # weight β for each client. This doesn't involve any aggregation at the
  # server, so this step and training can be a single round of communication.
  domain_batch_clients = [(cid, cds.padded_batch(domain_batch_hparams), crng)
                          for cid, cds, crng in clients]
  shared_input = {'params': server_state.params, 'alpha': alpha}
  # L^k, N^k, β^k
  client_domain_metrics = dict(
      domain_metrics_for_each_client(shared_input, domain_batch_clients))
  # Train for each client using scaling weights α and β.
  batch_clients = []
  for cid, cds, crng in clients:
    # Each client's per-client input carries its rng and its β from the
    # metrics pass.
    client_input = {
        'rng': crng,
        'beta': client_domain_metrics[cid]['beta']
    }
    batch_clients.append(
        (cid, cds.shuffle_repeat_batch(client_batch_hparams), client_input))
  client_diagnostics = {}
  # Mean delta params across clients.
  delta_params_sum = tree_util.tree_zeros_like(server_state.params)
  weight_sum = 0.
  # w^k: each client's averaging weight is its β.
  for cid, delta_params in train_for_each_client(shared_input, batch_clients):
    weight = client_domain_metrics[cid]['beta']
    delta_params_sum = tree_util.tree_add(
        delta_params_sum, tree_util.tree_weight(delta_params, weight))
    weight_sum += weight
    # l2 norm of the client update; diagnostic only.
    client_diagnostics[cid] = {
        'delta_l2_norm': tree_util.tree_l2_norm(delta_params)
    }
  mean_delta_params = tree_util.tree_inverse_weight(
      delta_params_sum, weight_sum)
  # Sum domain metrics across clients.
  sum_domain_loss = tree_util.tree_sum(
      d['domain_loss'] for d in client_domain_metrics.values())
  sum_domain_num = tree_util.tree_sum(
      d['domain_num'] for d in client_domain_metrics.values())
  server_state = server_update(server_state, mean_delta_params,
                               sum_domain_loss, sum_domain_num)
  return server_state, client_diagnostics
def client_step(client_step_state, batch):
  """Accumulates gradients over one batch, weighted by real example count."""
  next_rng, step_rng = jax.random.split(client_step_state['rng'])
  batch_grads = grad_fn(client_step_state['params'], batch, step_rng)
  # Sum of the example mask counts only the unmasked examples in the batch.
  batch_num = jnp.sum(batch[client_datasets.EXAMPLE_MASK_KEY])
  return {
      'params': client_step_state['params'],
      'rng': next_rng,
      'num_sum': client_step_state['num_sum'] + batch_num,
      'grads_sum': tree_util.tree_add(
          tree_util.tree_weight(batch_grads, batch_num),
          client_step_state['grads_sum']),
  }
def apply(
    server_state: mime.ServerState,
    clients: Sequence[Tuple[federated_data.ClientId,
                            client_datasets.ClientDataset, PRNGKey]]
) -> Tuple[mime.ServerState, Mapping[federated_data.ClientId, Any]]:
  """Runs one round of training with optional client update clipping.

  Clients train against a fixed optimizer state; their deltas are optionally
  clipped, then averaged by example count. A separate full-batch gradient at
  the server params is also computed and passed to the server update.

  Args:
    server_state: Current server state (params and optimizer state).
    clients: (client_id, dataset, rng) triples to train on this round.

  Returns:
    Updated server state and per-client diagnostics.
  """
  # Training across clients using fixed optimizer state.
  client_diagnostics = {}
  client_num_examples = {cid: len(cds) for cid, cds, _ in clients}
  batch_clients = [(cid, cds.shuffle_repeat_batch(client_batch_hparams), crng)
                   for cid, cds, crng in clients]
  # The optimizer state is broadcast to every client alongside the params.
  shared_input = {
      'params': server_state.params,
      'opt_state': server_state.opt_state
  }
  # Running weighted mean of client updates.
  delta_params_sum = tree_util.tree_zeros_like(server_state.params)
  num_examples_sum = 0.
  for client_id, delta_params in train_for_each_client(
      shared_input, batch_clients):
    num_examples = client_num_examples[client_id]
    # Record the pre-clip l2 norm before any clipping is applied.
    client_diagnostics[client_id] = {
        'delta_l2_norm': tree_util.tree_l2_norm(delta_params)
    }
    if client_delta_clip_norm is not None:
      # Clip the client delta by global norm; 'clipped' marks whether the
      # clip actually changed the update.
      delta_params = tree_util.tree_clip_by_global_norm(
          delta_params, client_delta_clip_norm)
      client_diagnostics[client_id]['clipped_delta_l2_norm'] = (
          tree_util.tree_l2_norm(delta_params))
      client_diagnostics[client_id]['clipped'] = jnp.not_equal(
          client_diagnostics[client_id]['delta_l2_norm'],
          client_diagnostics[client_id]['clipped_delta_l2_norm'])
    # Aggregation uses the (possibly clipped) delta, weighted by example count.
    delta_params_sum = tree_util.tree_add(
        delta_params_sum, tree_util.tree_weight(delta_params, num_examples))
    num_examples_sum += num_examples
  mean_delta_params = tree_util.tree_inverse_weight(
      delta_params_sum, num_examples_sum)
  # Compute full-batch gradient at server params on train data.
  grads_batch_clients = [(cid, cds.padded_batch(grads_batch_hparams), crng)
                         for cid, cds, crng in clients]
  grads_sum_total, num_sum_total = tree_util.tree_sum(
      (co for _, co in grads_for_each_client(server_state.params,
                                             grads_batch_clients)))
  server_grads = tree_util.tree_inverse_weight(grads_sum_total, num_sum_total)
  server_state = server_update(server_state, server_grads, mean_delta_params)
  return server_state, client_diagnostics
def apply(
    server_state: ServerState,
    clients: Sequence[Tuple[federated_data.ClientId,
                            client_datasets.ClientDataset, PRNGKey]]
) -> Tuple[ServerState, Mapping[federated_data.ClientId, Any]]:
  """Runs one round of control-variate-corrected federated training.

  Args:
    server_state: Current server state (params and optimizer state).
    clients: (client_id, dataset, rng) triples to train on this round.

  Returns:
    Updated server state and per-client diagnostics.
  """
  # Compute full-batch gradient at server params on train data.
  grads_clients = [(cid, dataset.padded_batch(grads_batch_hparams), rng)
                   for cid, dataset, rng in clients]
  grads_total, num_total = tree_util.tree_sum(
      co for _, co in grads_for_each_client(server_state.params,
                                            grads_clients))
  server_grads = tree_util.tree_inverse_weight(grads_total, num_total)
  # Control variant corrected training across clients: every client receives
  # the same params, optimizer state, and control variate.
  diagnostics = {}
  examples_per_client = {cid: len(dataset) for cid, dataset, _ in clients}
  train_clients = [(cid, dataset.shuffle_repeat_batch(client_batch_hparams),
                    rng) for cid, dataset, rng in clients]
  shared_input = {
      'params': server_state.params,
      'opt_state': server_state.opt_state,
      'control_variate': server_grads
  }
  # Running weighted mean of client updates, accumulated iteratively.
  weighted_delta_sum = tree_util.tree_zeros_like(server_state.params)
  total_examples = 0.
  for cid, delta in train_for_each_client(shared_input, train_clients):
    weight = examples_per_client[cid]
    weighted_delta_sum = tree_util.tree_add(
        weighted_delta_sum, tree_util.tree_weight(delta, weight))
    total_examples += weight
    # l2 norm of the client update; diagnostic only.
    diagnostics[cid] = {'delta_l2_norm': tree_util.tree_l2_norm(delta)}
  mean_delta = tree_util.tree_inverse_weight(weighted_delta_sum,
                                             total_examples)
  return server_update(server_state, server_grads, mean_delta), diagnostics
def expectation_step(trainer: ClientDeltaTrainer, cluster_params: List[Params],
                     client_cluster_ids: Dict[federated_data.ClientId, int],
                     clients: Sequence[Tuple[federated_data.ClientId,
                                             client_datasets.ClientDataset,
                                             PRNGKey]],
                     batch_hparams: client_datasets.ShuffleRepeatBatchHParams):
  """Updates each cluster's params using average delta from clients in this cluster.

  Args:
    trainer: Produces a delta params for each trained client.
    cluster_params: Current params for each cluster.
    client_cluster_ids: Maps each client id to its assigned cluster index.
    clients: (client_id, dataset, rng) triples to train on this round.
    batch_hparams: Hyperparameters for shuffle-repeat batching.

  Returns:
    Per-cluster weighted-average delta params; None for any cluster that saw
    no examples this round.
  """
  # Train each client starting from its corresponding cluster params, and
  # calculate weighted average of delta params within each cluster.
  num_examples = {
      client_id: len(dataset) for client_id, dataset, _ in clients
  }
  # Use tree_util.tree_zeros_like for consistency with the other algorithms
  # in this file (instead of a raw jax.tree_util.tree_map(jnp.zeros_like, .)).
  cluster_delta_params_sum = [
      tree_util.tree_zeros_like(params) for params in cluster_params
  ]
  cluster_num_examples_sum = [0 for _ in cluster_params]
  for client_id, delta_params in trainer.train_per_client_params([
      (client_id, dataset.shuffle_repeat_batch(batch_hparams), rng,
       cluster_params[client_cluster_ids[client_id]])
      for client_id, dataset, rng in clients
  ]):
    cluster_id = client_cluster_ids[client_id]
    cluster_delta_params_sum[cluster_id] = tree_util.tree_add(
        cluster_delta_params_sum[cluster_id],
        tree_util.tree_weight(delta_params, num_examples[client_id]))
    cluster_num_examples_sum[cluster_id] += num_examples[client_id]
  # Weighted average delta params, or None if no examples were seen for this
  # cluster.
  cluster_delta_params = []
  for delta_params_sum, num_examples_sum in zip(cluster_delta_params_sum,
                                                cluster_num_examples_sum):
    if num_examples_sum > 0:
      cluster_delta_params.append(
          tree_util.tree_inverse_weight(delta_params_sum, num_examples_sum))
    else:
      cluster_delta_params.append(None)
  return cluster_delta_params