Example #1
0
 def apply(
     server_state: ServerState,
     clients: Sequence[Tuple[federated_data.ClientId,
                             client_datasets.ClientDataset, PRNGKey]]
 ) -> Tuple[ServerState, Mapping[federated_data.ClientId, Any]]:
     """Runs one round of example-weighted federated averaging.

     Every client trains from ``server_state.params`` on
     ``shuffle_repeat_batch``-ed batches of its dataset. The resulting
     parameter deltas are averaged, with each client weighted by ``len`` of
     its dataset, and the average is handed to ``server_update``.

     Args:
       server_state: Current server state; ``server_state.params`` is the
         starting point for every client's local training.
       clients: ``(client_id, client_dataset, rng)`` triples for this round.

     Returns:
       A pair of the state returned by ``server_update`` and a mapping from
       client id to ``{'delta_l2_norm': ...}``, the l2 norm of that client's
       unweighted delta (kept for monitoring; the update does not use it).
     """
     # Per-client averaging weight: the number of examples in its dataset.
     example_counts = {cid: len(dataset) for cid, dataset, _ in clients}
     training_inputs = [
         (cid, dataset.shuffle_repeat_batch(client_batch_hparams), rng)
         for cid, dataset, rng in clients
     ]

     # Keep running sums (sum_k n_k * delta_k and sum_k n_k) instead of first
     # collecting every client's delta: deltas are as large as the model
     # parameters, so holding all of them at once can be too expensive.
     weighted_delta_total = tree_util.tree_zeros_like(server_state.params)
     total_examples = 0.
     diagnostics = {}
     for cid, delta in train_for_each_client(server_state.params,
                                             training_inputs):
         count = example_counts[cid]
         weighted_delta_total = tree_util.tree_add(
             weighted_delta_total, tree_util.tree_weight(delta, count))
         total_examples += count
         diagnostics[cid] = {'delta_l2_norm': tree_util.tree_l2_norm(delta)}

     # NOTE(review): if `clients` is empty, total_examples is still 0. here;
     # confirm callers never run a round without any clients.
     averaged_delta = tree_util.tree_inverse_weight(weighted_delta_total,
                                                    total_examples)
     return server_update(server_state, averaged_delta), diagnostics
Example #2
0
    def apply(
        server_state: ServerState,
        clients: Sequence[Tuple[federated_data.ClientId,
                                client_datasets.ClientDataset, PRNGKey]]
    ) -> Tuple[ServerState, Mapping[federated_data.ClientId, Any]]:
        """Runs one round of domain-reweighted federated averaging.

        The round has three stages:

        1. Each client evaluates the current server params on padded batches
           (``domain_metrics_for_each_client``), producing its domain loss
           ``L^k``, domain count ``N^k`` and scaling weight ``beta^k``. This
           stage needs no server-side aggregation, so it can share a single
           communication round with training.
        2. Each client trains from the server params using the shared domain
           weights ``alpha`` and its own ``beta^k``. The deltas are averaged
           with ``beta^k`` as the per-client weight.
        3. Domain losses and counts are summed over clients and passed to
           ``server_update`` together with the averaged delta.

        Args:
          server_state: Current server state; supplies ``params``,
            ``domain_weights`` and ``domain_window``.
          clients: ``(client_id, client_dataset, rng)`` triples for this round.

        Returns:
          A pair of the state returned by ``server_update`` and a mapping from
          client id to ``{'delta_l2_norm': ...}`` for that client's delta.
        """
        # alpha: domain weights divided by the mean of the domain window
        # along axis 0.
        window_mean = jnp.mean(jnp.asarray(server_state.domain_window), axis=0)
        alpha = server_state.domain_weights / window_mean

        # Stage 1: per-client domain metrics (L^k, N^k, beta^k).
        eval_inputs = [(cid, dataset.padded_batch(domain_batch_hparams), rng)
                       for cid, dataset, rng in clients]
        shared_input = {'params': server_state.params, 'alpha': alpha}
        metrics_by_client = dict(
            domain_metrics_for_each_client(shared_input, eval_inputs))

        # Stage 2: local training scaled by alpha (shared) and beta^k.
        train_inputs = []
        for cid, dataset, rng in clients:
            per_client = {'rng': rng, 'beta': metrics_by_client[cid]['beta']}
            batches = dataset.shuffle_repeat_batch(client_batch_hparams)
            train_inputs.append((cid, batches, per_client))

        # Running beta-weighted sum of deltas (w^k = beta^k).
        weighted_delta_total = tree_util.tree_zeros_like(server_state.params)
        beta_total = 0.
        diagnostics = {}
        for cid, delta in train_for_each_client(shared_input, train_inputs):
            beta = metrics_by_client[cid]['beta']
            weighted_delta_total = tree_util.tree_add(
                weighted_delta_total, tree_util.tree_weight(delta, beta))
            beta_total += beta
            diagnostics[cid] = {'delta_l2_norm': tree_util.tree_l2_norm(delta)}
        averaged_delta = tree_util.tree_inverse_weight(weighted_delta_total,
                                                       beta_total)

        # Stage 3: aggregate domain statistics across clients and update.
        total_domain_loss = tree_util.tree_sum(
            m['domain_loss'] for m in metrics_by_client.values())
        total_domain_num = tree_util.tree_sum(
            m['domain_num'] for m in metrics_by_client.values())
        updated_state = server_update(server_state, averaged_delta,
                                      total_domain_loss, total_domain_num)
        return updated_state, diagnostics
Example #3
0
    def apply(
        server_state: mime.ServerState,
        clients: Sequence[Tuple[federated_data.ClientId,
                                client_datasets.ClientDataset, PRNGKey]]
    ) -> Tuple[mime.ServerState, Mapping[federated_data.ClientId, Any]]:
        """Runs one training round with a frozen optimizer state.

        Clients train from the server params and the server's current
        ``opt_state`` (not updated during local training). Their deltas are
        optionally clipped to ``client_delta_clip_norm`` and then averaged,
        weighted by ``len`` of each client's dataset. Separately, a full-batch
        gradient at the (pre-update) server params is computed from padded
        batches. Both results are passed to ``server_update``.

        Args:
          server_state: Current server state; supplies ``params`` and
            ``opt_state``.
          clients: ``(client_id, client_dataset, rng)`` triples for this round.

        Returns:
          A pair of the state returned by ``server_update`` and per-client
          diagnostics. Each diagnostics entry always contains
          ``'delta_l2_norm'`` (norm before clipping). When clipping is
          configured it also contains ``'clipped_delta_l2_norm'`` and
          ``'clipped'`` (whether clipping changed the norm).
        """
        example_counts = {cid: len(dataset) for cid, dataset, _ in clients}
        train_inputs = [
            (cid, dataset.shuffle_repeat_batch(client_batch_hparams), rng)
            for cid, dataset, rng in clients
        ]
        shared_input = {
            'params': server_state.params,
            'opt_state': server_state.opt_state,
        }

        # Running example-weighted sum of (possibly clipped) client deltas.
        weighted_delta_total = tree_util.tree_zeros_like(server_state.params)
        total_examples = 0.
        diagnostics = {}
        for cid, delta in train_for_each_client(shared_input, train_inputs):
            count = example_counts[cid]
            raw_norm = tree_util.tree_l2_norm(delta)
            stats = {'delta_l2_norm': raw_norm}
            if client_delta_clip_norm is not None:
                delta = tree_util.tree_clip_by_global_norm(
                    delta, client_delta_clip_norm)
                clipped_norm = tree_util.tree_l2_norm(delta)
                stats['clipped_delta_l2_norm'] = clipped_norm
                stats['clipped'] = jnp.not_equal(raw_norm, clipped_norm)
            diagnostics[cid] = stats
            weighted_delta_total = tree_util.tree_add(
                weighted_delta_total, tree_util.tree_weight(delta, count))
            total_examples += count
        averaged_delta = tree_util.tree_inverse_weight(weighted_delta_total,
                                                       total_examples)

        # Full-batch gradient at the server params over the training data.
        # Each client's output is unpacked as a (gradient sum, count) pair
        # once summed; presumably per-client sums, verify against
        # grads_for_each_client.
        grad_inputs = [(cid, dataset.padded_batch(grads_batch_hparams), rng)
                       for cid, dataset, rng in clients]
        per_client_grads = (
            output for _, output in grads_for_each_client(
                server_state.params, grad_inputs))
        grad_total, count_total = tree_util.tree_sum(per_client_grads)
        server_grads = tree_util.tree_inverse_weight(grad_total, count_total)

        updated_state = server_update(server_state, server_grads,
                                      averaged_delta)
        return updated_state, diagnostics
Example #4
0
    def apply(
        server_state: ServerState,
        clients: Sequence[Tuple[federated_data.ClientId,
                                client_datasets.ClientDataset, PRNGKey]]
    ) -> Tuple[ServerState, Mapping[federated_data.ClientId, Any]]:
        """Runs one round of control-variate-corrected federated training.

        First, a full-batch gradient at the current server params is computed
        from padded batches of every client's data. That gradient is sent to
        every client as ``control_variate``, alongside the server params and
        ``opt_state``, for local training. The clients' deltas are averaged,
        weighted by ``len`` of each client's dataset. Both the server gradient
        and the averaged delta are passed to ``server_update``.

        Args:
          server_state: Current server state; supplies ``params`` and
            ``opt_state``.
          clients: ``(client_id, client_dataset, rng)`` triples for this round.

        Returns:
          A pair of the state returned by ``server_update`` and a mapping from
          client id to ``{'delta_l2_norm': ...}`` for that client's delta.
        """
        # Full-batch gradient at the server params over the training data.
        # Each client's output is unpacked as a (gradient sum, count) pair
        # once summed; presumably per-client sums, verify against
        # grads_for_each_client.
        grad_inputs = [(cid, dataset.padded_batch(grads_batch_hparams), rng)
                       for cid, dataset, rng in clients]
        per_client_grads = (
            output for _, output in grads_for_each_client(
                server_state.params, grad_inputs))
        grad_total, count_total = tree_util.tree_sum(per_client_grads)
        server_grads = tree_util.tree_inverse_weight(grad_total, count_total)

        # Local training corrected by the server gradient (control variate).
        example_counts = {cid: len(dataset) for cid, dataset, _ in clients}
        train_inputs = [
            (cid, dataset.shuffle_repeat_batch(client_batch_hparams), rng)
            for cid, dataset, rng in clients
        ]
        shared_input = {
            'params': server_state.params,
            'opt_state': server_state.opt_state,
            'control_variate': server_grads,
        }

        # Running example-weighted sum of client deltas.
        weighted_delta_total = tree_util.tree_zeros_like(server_state.params)
        total_examples = 0.
        diagnostics = {}
        for cid, delta in train_for_each_client(shared_input, train_inputs):
            count = example_counts[cid]
            weighted_delta_total = tree_util.tree_add(
                weighted_delta_total, tree_util.tree_weight(delta, count))
            total_examples += count
            diagnostics[cid] = {'delta_l2_norm': tree_util.tree_l2_norm(delta)}
        averaged_delta = tree_util.tree_inverse_weight(weighted_delta_total,
                                                       total_examples)

        updated_state = server_update(server_state, server_grads,
                                      averaged_delta)
        return updated_state, diagnostics