def apply(
    server_state: ServerState,
    clients: Sequence[Tuple[federated_data.ClientId,
                            client_datasets.ClientDataset, PRNGKey]]
) -> Tuple[ServerState, Mapping[federated_data.ClientId, Any]]:
  """Runs one round of client training followed by a server update.

  Every client trains starting from ``server_state.params``. The resulting
  parameter deltas are averaged, each weighted by the client's number of
  examples (``len`` of its dataset). The average is then handed to
  ``server_update``.

  Args:
    server_state: Current server state. Its ``params`` are passed to
      ``train_for_each_client`` as the shared input.
    clients: ``(client_id, dataset, rng)`` triples participating this round.

  Returns:
    A pair of the updated server state and a dict keyed by client id. Each
    value holds ``'delta_l2_norm'``, the l2 norm of that client's delta.
  """
  example_counts = {
      client_id: len(dataset) for client_id, dataset, _ in clients
  }
  training_inputs = [
      (client_id, dataset.shuffle_repeat_batch(client_batch_hparams), rng)
      for client_id, dataset, rng in clients
  ]

  diagnostics = {}
  # The weighted sum is folded in one client at a time. Keeping every
  # client's delta around would cost memory proportional to the model size
  # times the number of clients.
  weighted_delta_total = tree_util.tree_zeros_like(server_state.params)
  total_weight = 0.
  for client_id, client_delta in train_for_each_client(
      server_state.params, training_inputs):
    weight = example_counts[client_id]
    scaled_delta = tree_util.tree_weight(client_delta, weight)
    weighted_delta_total = tree_util.tree_add(weighted_delta_total,
                                              scaled_delta)
    total_weight += weight
    # Recorded for monitoring only; the update rule does not use it.
    diagnostics[client_id] = {
        'delta_l2_norm': tree_util.tree_l2_norm(client_delta)
    }

  averaged_delta = tree_util.tree_inverse_weight(weighted_delta_total,
                                                 total_weight)
  return server_update(server_state, averaged_delta), diagnostics
def apply(
    server_state: ServerState,
    clients: Sequence[Tuple[federated_data.ClientId,
                            client_datasets.ClientDataset, PRNGKey]]
) -> Tuple[ServerState, Mapping[federated_data.ClientId, Any]]:
  """Runs one round using domain scaling weights alpha (global) and beta.

  1. Compute alpha by dividing ``server_state.domain_weights`` by the mean
     of ``server_state.domain_window`` along axis 0.
  2. Gather per-client domain metrics (``domain_loss``, ``domain_num``,
     ``beta``) with ``domain_metrics_for_each_client``.
  3. Train each client with its own beta, then average the deltas weighted
     by beta.
  4. Pass the averaged delta and the summed domain loss/count to
     ``server_update``.

  Args:
    server_state: Current server state. Provides params, domain weights and
      the domain window.
    clients: ``(client_id, dataset, rng)`` triples participating this round.

  Returns:
    A pair of the updated server state and a dict keyed by client id. Each
    value holds ``'delta_l2_norm'``, the l2 norm of that client's delta.
  """
  # Alpha: domain weights rescaled by the window mean.
  window_mean = jnp.mean(jnp.asarray(server_state.domain_window), axis=0)
  alpha = server_state.domain_weights / window_mean

  # Step 2 needs no server-side aggregation. In a deployment it could
  # therefore share a communication round with training.
  metric_inputs = [
      (client_id, dataset.padded_batch(domain_batch_hparams), rng)
      for client_id, dataset, rng in clients
  ]
  shared_input = {'params': server_state.params, 'alpha': alpha}
  # Per-client L^k (domain_loss), N^k (domain_num) and beta^k.
  domain_metrics = dict(
      domain_metrics_for_each_client(shared_input, metric_inputs))

  # Each client trains with the shared alpha plus its own beta.
  training_inputs = [
      (client_id, dataset.shuffle_repeat_batch(client_batch_hparams), {
          'rng': rng,
          'beta': domain_metrics[client_id]['beta']
      }) for client_id, dataset, rng in clients
  ]

  diagnostics = {}
  weighted_delta_total = tree_util.tree_zeros_like(server_state.params)
  beta_total = 0.
  for client_id, client_delta in train_for_each_client(shared_input,
                                                       training_inputs):
    # Aggregation weight w^k is the client's beta.
    client_beta = domain_metrics[client_id]['beta']
    weighted_delta_total = tree_util.tree_add(
        weighted_delta_total, tree_util.tree_weight(client_delta, client_beta))
    beta_total += client_beta
    diagnostics[client_id] = {
        'delta_l2_norm': tree_util.tree_l2_norm(client_delta)
    }
  averaged_delta = tree_util.tree_inverse_weight(weighted_delta_total,
                                                 beta_total)

  def _total(metric_name):
    """Sums one domain metric over all clients."""
    return tree_util.tree_sum(
        metrics[metric_name] for metrics in domain_metrics.values())

  new_state = server_update(server_state, averaged_delta,
                            _total('domain_loss'), _total('domain_num'))
  return new_state, diagnostics
def apply(
    server_state: mime.ServerState,
    clients: Sequence[Tuple[federated_data.ClientId,
                            client_datasets.ClientDataset, PRNGKey]]
) -> Tuple[mime.ServerState, Mapping[federated_data.ClientId, Any]]:
  """Runs one round: local training with the server's optimizer state, then a server step.

  Clients receive the server ``params`` and ``opt_state`` as shared input.
  When ``client_delta_clip_norm`` is set, each client delta is clipped by
  global norm. The deltas are then averaged, weighted by client example
  count. A full-batch gradient at the server params is computed from padded
  batches via ``grads_for_each_client``. Both the gradient and the averaged
  delta are passed to ``server_update``.

  Args:
    server_state: Current server state. Its ``params`` and ``opt_state`` are
      shared with every client.
    clients: ``(client_id, dataset, rng)`` triples participating this round.

  Returns:
    A pair of the updated server state and a dict keyed by client id. Each
    value always holds ``'delta_l2_norm'``. When clipping is enabled it also
    holds ``'clipped_delta_l2_norm'`` and the boolean ``'clipped'``.
  """
  diagnostics = {}
  example_counts = {
      client_id: len(dataset) for client_id, dataset, _ in clients
  }
  training_inputs = [
      (client_id, dataset.shuffle_repeat_batch(client_batch_hparams), rng)
      for client_id, dataset, rng in clients
  ]
  # The server optimizer state is sent as-is; clients do not update it.
  shared_input = {
      'params': server_state.params,
      'opt_state': server_state.opt_state,
  }

  weighted_delta_total = tree_util.tree_zeros_like(server_state.params)
  total_weight = 0.
  for client_id, client_delta in train_for_each_client(shared_input,
                                                       training_inputs):
    weight = example_counts[client_id]
    raw_norm = tree_util.tree_l2_norm(client_delta)
    entry = {'delta_l2_norm': raw_norm}
    if client_delta_clip_norm is not None:
      client_delta = tree_util.tree_clip_by_global_norm(
          client_delta, client_delta_clip_norm)
      clipped_norm = tree_util.tree_l2_norm(client_delta)
      entry['clipped_delta_l2_norm'] = clipped_norm
      # True when clipping actually changed the delta's norm.
      entry['clipped'] = jnp.not_equal(raw_norm, clipped_norm)
    diagnostics[client_id] = entry
    weighted_delta_total = tree_util.tree_add(
        weighted_delta_total, tree_util.tree_weight(client_delta, weight))
    total_weight += weight
  averaged_delta = tree_util.tree_inverse_weight(weighted_delta_total,
                                                 total_weight)

  # Full-batch gradient at the (pre-update) server params over train data.
  gradient_inputs = [
      (client_id, dataset.padded_batch(grads_batch_hparams), rng)
      for client_id, dataset, rng in clients
  ]
  # Each client output appears to be a (grad_sum, count) pair; they are
  # summed across clients and then normalized.
  per_client_outputs = (
      output for _, output in grads_for_each_client(server_state.params,
                                                    gradient_inputs))
  grad_total, count_total = tree_util.tree_sum(per_client_outputs)
  server_grads = tree_util.tree_inverse_weight(grad_total, count_total)

  new_state = server_update(server_state, server_grads, averaged_delta)
  return new_state, diagnostics
def apply(
    server_state: ServerState,
    clients: Sequence[Tuple[federated_data.ClientId,
                            client_datasets.ClientDataset, PRNGKey]]
) -> Tuple[ServerState, Mapping[federated_data.ClientId, Any]]:
  """Runs one round in which the server gradient is a control variate for clients.

  First computes a full-batch gradient at the server params from padded
  batches via ``grads_for_each_client``. Clients then train with that
  gradient passed as ``'control_variate'`` alongside the server ``params``
  and ``opt_state``. Their deltas are averaged, weighted by example count.
  Both the gradient and the averaged delta go to ``server_update``.

  Args:
    server_state: Current server state. Its ``params`` and ``opt_state`` are
      shared with every client.
    clients: ``(client_id, dataset, rng)`` triples participating this round.

  Returns:
    A pair of the updated server state and a dict keyed by client id. Each
    value holds ``'delta_l2_norm'``, the l2 norm of that client's delta.
  """
  # Full-batch gradient at the server params over train data.
  gradient_inputs = [
      (client_id, dataset.padded_batch(grads_batch_hparams), rng)
      for client_id, dataset, rng in clients
  ]
  # Each client output appears to be a (grad_sum, count) pair; they are
  # summed across clients and then normalized.
  per_client_outputs = (
      output for _, output in grads_for_each_client(server_state.params,
                                                    gradient_inputs))
  grad_total, count_total = tree_util.tree_sum(per_client_outputs)
  server_grads = tree_util.tree_inverse_weight(grad_total, count_total)

  # Client training corrected by the control variate computed above.
  diagnostics = {}
  example_counts = {
      client_id: len(dataset) for client_id, dataset, _ in clients
  }
  training_inputs = [
      (client_id, dataset.shuffle_repeat_batch(client_batch_hparams), rng)
      for client_id, dataset, rng in clients
  ]
  shared_input = {
      'params': server_state.params,
      'opt_state': server_state.opt_state,
      'control_variate': server_grads,
  }

  weighted_delta_total = tree_util.tree_zeros_like(server_state.params)
  total_weight = 0.
  for client_id, client_delta in train_for_each_client(shared_input,
                                                       training_inputs):
    weight = example_counts[client_id]
    weighted_delta_total = tree_util.tree_add(
        weighted_delta_total, tree_util.tree_weight(client_delta, weight))
    total_weight += weight
    diagnostics[client_id] = {
        'delta_l2_norm': tree_util.tree_l2_norm(client_delta)
    }
  averaged_delta = tree_util.tree_inverse_weight(weighted_delta_total,
                                                 total_weight)

  new_state = server_update(server_state, server_grads, averaged_delta)
  return new_state, diagnostics