Example #1
File: fed_avg.py  Project: google/fedjax
 def apply(
     server_state: ServerState,
     clients: Sequence[Tuple[federated_data.ClientId,
                             client_datasets.ClientDataset, PRNGKey]]
 ) -> Tuple[ServerState, Mapping[federated_data.ClientId, Any]]:
     client_num_examples = {cid: len(cds) for cid, cds, _ in clients}
     batch_clients = [(cid, cds.shuffle_repeat_batch(client_batch_hparams),
                       crng) for cid, cds, crng in clients]
     client_diagnostics = {}
     # Running weighted mean of client updates. We accumulate iteratively to
     # avoid loading all client outputs into memory at once, since they can be
     # prohibitively large depending on the size of the model parameters.
     delta_params_sum = tree_util.tree_zeros_like(server_state.params)
     num_examples_sum = 0.
     for client_id, delta_params in train_for_each_client(
             server_state.params, batch_clients):
         num_examples = client_num_examples[client_id]
         delta_params_sum = tree_util.tree_add(
             delta_params_sum,
             tree_util.tree_weight(delta_params, num_examples))
         num_examples_sum += num_examples
         # We record the l2 norm of client updates as an example, but it is not
         # required for the algorithm.
         client_diagnostics[client_id] = {
             'delta_l2_norm': tree_util.tree_l2_norm(delta_params)
         }
     mean_delta_params = tree_util.tree_inverse_weight(
         delta_params_sum, num_examples_sum)
     server_state = server_update(server_state, mean_delta_params)
     return server_state, client_diagnostics
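The aggregation pattern above (weight each client delta by its example count, accumulate, then divide by the total) can be exercised in isolation. A minimal sketch, assuming tree_util is fedjax.core.tree_util; the client deltas and counts here are made up for illustration:

import jax.numpy as jnp
from fedjax.core import tree_util

# Hypothetical per-client (delta_params, num_examples) pairs.
client_updates = [
    ({'w': jnp.array([1.0, 2.0])}, 10.),
    ({'w': jnp.array([3.0, 4.0])}, 30.),
]
delta_params_sum = tree_util.tree_zeros_like(client_updates[0][0])
num_examples_sum = 0.
for delta_params, num_examples in client_updates:
    delta_params_sum = tree_util.tree_add(
        delta_params_sum, tree_util.tree_weight(delta_params, num_examples))
    num_examples_sum += num_examples
mean_delta_params = tree_util.tree_inverse_weight(delta_params_sum,
                                                  num_examples_sum)
# mean_delta_params['w'] == (10 * [1, 2] + 30 * [3, 4]) / 40 == [2.5, 3.5]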
Example #2
 def test_tree_weight(self):
     pytree = {
         'x': jnp.array([[[4, 5]], [[1, 1]]]),
         'y': jnp.array([[3], [1]]),
     }
     weight = 2.0
     weight_pytree = tree_util.tree_weight(pytree, weight)
     self.assertAllEqual(weight_pytree['x'], [[[8.0, 10.0]], [[2.0, 2.0]]])
     self.assertAllEqual(weight_pytree['y'], [[6.0], [2.0]])
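The test pins down the semantics of tree_util.tree_weight: every leaf is multiplied by the scalar weight. In plain JAX the same behavior can be sketched as follows (a sketch of the semantics, not necessarily fedjax's exact implementation):

import jax
import jax.numpy as jnp

def tree_weight_sketch(pytree, weight):
    # Scale every leaf of the pytree by the scalar weight.
    return jax.tree_util.tree_map(lambda leaf: leaf * weight, pytree)

# tree_weight_sketch({'y': jnp.array([3., 1.])}, 2.0) -> {'y': [6., 2.]}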
Example #3
    def apply(
        server_state: ServerState,
        clients: Sequence[Tuple[federated_data.ClientId,
                                client_datasets.ClientDataset, PRNGKey]]
    ) -> Tuple[ServerState, Mapping[federated_data.ClientId, Any]]:
        # α
        alpha = server_state.domain_weights / jnp.mean(
            jnp.asarray(server_state.domain_window), axis=0)
        # First pass to calculate initial domain loss, domain num, and scaling
        # weight β for each client. This doesn't involve any aggregation at the
        # server, so this step and training can be a single round of communication.
        domain_batch_clients = [(cid, cds.padded_batch(domain_batch_hparams),
                                 crng) for cid, cds, crng in clients]
        shared_input = {'params': server_state.params, 'alpha': alpha}
        # L^k, N^k, β^k
        client_domain_metrics = dict(
            domain_metrics_for_each_client(shared_input, domain_batch_clients))
        # Train for each client using scaling weights α and β.
        batch_clients = []
        for cid, cds, crng in clients:
            client_input = {
                'rng': crng,
                'beta': client_domain_metrics[cid]['beta']
            }
            batch_clients.append(
                (cid, cds.shuffle_repeat_batch(client_batch_hparams),
                 client_input))

        client_diagnostics = {}
        # Mean delta params across clients.
        delta_params_sum = tree_util.tree_zeros_like(server_state.params)
        weight_sum = 0.
        # w^k
        for cid, delta_params in train_for_each_client(shared_input,
                                                       batch_clients):
            weight = client_domain_metrics[cid]['beta']
            delta_params_sum = tree_util.tree_add(
                delta_params_sum, tree_util.tree_weight(delta_params, weight))
            weight_sum += weight
            client_diagnostics[cid] = {
                'delta_l2_norm': tree_util.tree_l2_norm(delta_params)
            }
        mean_delta_params = tree_util.tree_inverse_weight(
            delta_params_sum, weight_sum)
        # Sum domain metrics across clients.
        sum_domain_loss = tree_util.tree_sum(
            d['domain_loss'] for d in client_domain_metrics.values())
        sum_domain_num = tree_util.tree_sum(
            d['domain_num'] for d in client_domain_metrics.values())
        server_state = server_update(server_state, mean_delta_params,
                                     sum_domain_loss, sum_domain_num)
        return server_state, client_diagnostics
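Unlike example #1, the aggregation weight here is the per-client scaling weight β rather than the raw example count, and per-domain metrics are summed across clients with tree_util.tree_sum. A minimal sketch of that summation, with made-up per-domain loss vectors:

import jax.numpy as jnp
from fedjax.core import tree_util

# Hypothetical per-client, per-domain loss vectors (two domains).
client_domain_loss = [jnp.array([0.5, 1.0]), jnp.array([1.5, 2.0])]
sum_domain_loss = tree_util.tree_sum(client_domain_loss)
# sum_domain_loss == [2.0, 3.0]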
Example #4
 def client_step(client_step_state, batch):
     rng, use_rng = jax.random.split(client_step_state['rng'])
     grads = grad_fn(client_step_state['params'], batch, use_rng)
     num = jnp.sum(batch[client_datasets.EXAMPLE_MASK_KEY])
     grads_sum = tree_util.tree_add(tree_util.tree_weight(grads, num),
                                    client_step_state['grads_sum'])
     next_client_step_state = {
         'params': client_step_state['params'],
         'rng': rng,
         'num_sum': client_step_state['num_sum'] + num,
         'grads_sum': grads_sum
     }
     return next_client_step_state
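client_step is the per-batch body of a fold over one client's batches; the carried state accumulates an example-weighted gradient sum, where num counts the real (unmasked) examples in a padded batch. Roughly, a driver would thread the state like this (a sketch; client_batches, params, rng, and grad_fn are placeholders):

state = {
    'params': params,
    'rng': rng,
    'num_sum': 0.,
    'grads_sum': tree_util.tree_zeros_like(params),
}
for batch in client_batches:
    state = client_step(state, batch)
# Average gradient over all of the client's examples.
avg_grads = tree_util.tree_inverse_weight(state['grads_sum'], state['num_sum'])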
Example #5
    def apply(
        server_state: mime.ServerState,
        clients: Sequence[Tuple[federated_data.ClientId,
                                client_datasets.ClientDataset, PRNGKey]]
    ) -> Tuple[mime.ServerState, Mapping[federated_data.ClientId, Any]]:
        # Training across clients using fixed optimizer state.
        client_diagnostics = {}
        client_num_examples = {cid: len(cds) for cid, cds, _ in clients}
        batch_clients = [(cid, cds.shuffle_repeat_batch(client_batch_hparams),
                          crng) for cid, cds, crng in clients]
        shared_input = {
            'params': server_state.params,
            'opt_state': server_state.opt_state
        }
        # Running weighted mean of client updates.
        delta_params_sum = tree_util.tree_zeros_like(server_state.params)
        num_examples_sum = 0.
        for client_id, delta_params in train_for_each_client(
                shared_input, batch_clients):
            num_examples = client_num_examples[client_id]
            client_diagnostics[client_id] = {
                'delta_l2_norm': tree_util.tree_l2_norm(delta_params)
            }
            if client_delta_clip_norm is not None:
                delta_params = tree_util.tree_clip_by_global_norm(
                    delta_params, client_delta_clip_norm)
                client_diagnostics[client_id]['clipped_delta_l2_norm'] = (
                    tree_util.tree_l2_norm(delta_params))
                client_diagnostics[client_id]['clipped'] = jnp.not_equal(
                    client_diagnostics[client_id]['delta_l2_norm'],
                    client_diagnostics[client_id]['clipped_delta_l2_norm'])
            delta_params_sum = tree_util.tree_add(
                delta_params_sum,
                tree_util.tree_weight(delta_params, num_examples))
            num_examples_sum += num_examples
        mean_delta_params = tree_util.tree_inverse_weight(
            delta_params_sum, num_examples_sum)

        # Compute full-batch gradient at server params on train data.
        grads_batch_clients = [(cid, cds.padded_batch(grads_batch_hparams),
                                crng) for cid, cds, crng in clients]
        grads_sum_total, num_sum_total = tree_util.tree_sum(
            (co for _, co in grads_for_each_client(server_state.params,
                                                   grads_batch_clients)))
        server_grads = tree_util.tree_inverse_weight(grads_sum_total,
                                                     num_sum_total)

        server_state = server_update(server_state, server_grads,
                                     mean_delta_params)
        return server_state, client_diagnostics
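The optional clipping branch caps each client delta's global l2 norm before it enters the weighted mean. A minimal sketch of the effect, assuming standard global-norm clipping semantics for tree_util.tree_clip_by_global_norm:

import jax.numpy as jnp
from fedjax.core import tree_util

delta_params = {'w': jnp.array([3.0, 4.0])}  # global l2 norm 5.0
clipped = tree_util.tree_clip_by_global_norm(delta_params, 1.0)
# clipped['w'] == [0.6, 0.8]: the norm changed, so the 'clipped'
# diagnostic above would be True for this client.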
Example #6
    def apply(
        server_state: ServerState,
        clients: Sequence[Tuple[federated_data.ClientId,
                                client_datasets.ClientDataset, PRNGKey]]
    ) -> Tuple[ServerState, Mapping[federated_data.ClientId, Any]]:
        # Compute full-batch gradient at server params on train data.
        grads_batch_clients = [(cid, cds.padded_batch(grads_batch_hparams),
                                crng) for cid, cds, crng in clients]
        grads_sum_total, num_sum_total = tree_util.tree_sum(
            (co for _, co in grads_for_each_client(server_state.params,
                                                   grads_batch_clients)))
        server_grads = tree_util.tree_inverse_weight(grads_sum_total,
                                                     num_sum_total)
        # Control variate corrected training across clients.
        client_diagnostics = {}
        client_num_examples = {cid: len(cds) for cid, cds, _ in clients}
        batch_clients = [(cid, cds.shuffle_repeat_batch(client_batch_hparams),
                          crng) for cid, cds, crng in clients]
        shared_input = {
            'params': server_state.params,
            'opt_state': server_state.opt_state,
            'control_variate': server_grads
        }
        # Running weighted mean of client updates.
        delta_params_sum = tree_util.tree_zeros_like(server_state.params)
        num_examples_sum = 0.
        for client_id, delta_params in train_for_each_client(
                shared_input, batch_clients):
            num_examples = client_num_examples[client_id]
            delta_params_sum = tree_util.tree_add(
                delta_params_sum,
                tree_util.tree_weight(delta_params, num_examples))
            num_examples_sum += num_examples
            client_diagnostics[client_id] = {
                'delta_l2_norm': tree_util.tree_l2_norm(delta_params)
            }
        mean_delta_params = tree_util.tree_inverse_weight(
            delta_params_sum, num_examples_sum)

        server_state = server_update(server_state, server_grads,
                                     mean_delta_params)
        return server_state, client_diagnostics
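The server gradient is assembled from per-client (grads_sum, num_sum) pairs: tree_sum adds the pairs leafwise (tuples are pytrees too), and tree_inverse_weight divides by the total example count. A small sketch with made-up client outputs:

import jax.numpy as jnp
from fedjax.core import tree_util

# Hypothetical (grads_sum, num_sum) outputs from two clients.
client_outputs = [({'w': jnp.array([2.0])}, 4.0),
                  ({'w': jnp.array([6.0])}, 4.0)]
grads_sum_total, num_sum_total = tree_util.tree_sum(client_outputs)
server_grads = tree_util.tree_inverse_weight(grads_sum_total, num_sum_total)
# server_grads['w'] == [1.0]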
Example #7
def expectation_step(trainer: ClientDeltaTrainer, cluster_params: List[Params],
                     client_cluster_ids: Dict[federated_data.ClientId, int],
                     clients: Sequence[Tuple[federated_data.ClientId,
                                             client_datasets.ClientDataset,
                                             PRNGKey]],
                     batch_hparams: client_datasets.ShuffleRepeatBatchHParams):
    """Updates each cluster's params using average delta from clients in this cluster."""
    # Train each client starting from its corresponding cluster params, and
    # calculate weighted average of delta params within each cluster.
    num_examples = {
        client_id: len(dataset)
        for client_id, dataset, _ in clients
    }
    cluster_delta_params_sum = [
        jax.tree_util.tree_map(jnp.zeros_like, params)
        for params in cluster_params
    ]
    cluster_num_examples_sum = [0 for _ in cluster_params]
    for client_id, delta_params in trainer.train_per_client_params([
        (client_id, dataset.shuffle_repeat_batch(batch_hparams), rng,
         cluster_params[client_cluster_ids[client_id]])
            for client_id, dataset, rng in clients
    ]):
        cluster_id = client_cluster_ids[client_id]
        cluster_delta_params_sum[cluster_id] = tree_util.tree_add(
            cluster_delta_params_sum[cluster_id],
            tree_util.tree_weight(delta_params, num_examples[client_id]))
        cluster_num_examples_sum[cluster_id] += num_examples[client_id]
    # Weighted average delta params, or None if no examples were seen for this
    # cluster.
    cluster_delta_params = []
    for delta_params_sum, num_examples_sum in zip(cluster_delta_params_sum,
                                                  cluster_num_examples_sum):
        if num_examples_sum > 0:
            cluster_delta_params.append(
                tree_util.tree_inverse_weight(delta_params_sum,
                                              num_examples_sum))
        else:
            cluster_delta_params.append(None)
    return cluster_delta_params
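A cluster that received no clients this round yields None instead of a zero delta, so a caller should skip its update. A hypothetical caller sketch, where server_update is an assumed helper (not part of the code above) that applies a mean delta to cluster params:

for cluster_id, delta_params in enumerate(cluster_delta_params):
    if delta_params is None:
        continue  # No examples seen for this cluster; keep its params.
    # server_update is a hypothetical helper, not defined in this example.
    cluster_params[cluster_id] = server_update(cluster_params[cluster_id],
                                               delta_params)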