def fake_algorithm():
  """Counts the number of clients and sums up the 'x' feature."""

  def init():
    # num_clients, sum_values
    return 0, 0

  def apply(state, clients):
    num_clients, sum_values = state
    for _, dataset, _ in clients:
      num_clients += 1
      for x in dataset.all_examples()['x']:
        sum_values += x
    state = num_clients, sum_values
    return state, None

  return federated_algorithm.FederatedAlgorithm(init, apply)
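
A minimal usage sketch for the counting algorithm above (hypothetical data; it assumes fedjax is installed and that client_datasets.ClientDataset accepts a dict of numpy feature arrays, which should be verified against your version):

import jax
import numpy as np

from fedjax.core import client_datasets

algorithm = fake_algorithm()
state = algorithm.init()
# Clients are (client_id, dataset, rng) triples, matching apply() above.
clients = [(b'client_0',
            client_datasets.ClientDataset({'x': np.array([1., 2., 3.])}),
            jax.random.PRNGKey(0))]
state, _ = algorithm.apply(state, clients)
num_clients, sum_values = state  # num_clients == 1, sum_values == 6.0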
Example #2
def mime(
    per_example_loss: Callable[[Params, BatchExample, PRNGKey], jnp.ndarray],
    base_optimizer: optimizers.Optimizer,
    client_batch_hparams: client_datasets.ShuffleRepeatBatchHParams,
    grads_batch_hparams: client_datasets.PaddedBatchHParams,
    server_learning_rate: float,
    regularizer: Optional[Callable[[Params], jnp.ndarray]] = None
) -> federated_algorithm.FederatedAlgorithm:
    """Builds mime.

    Args:
      per_example_loss: A function from (params, batch_example, rng) to a vector
        of loss values for each example in the batch. This is used in both the
        server gradient computation and gradient descent training.
      base_optimizer: Base optimizer to mimic.
      client_batch_hparams: Hyperparameters for batching client datasets for
        training.
      grads_batch_hparams: Hyperparameters for batching client datasets for the
        server gradient computation.
      server_learning_rate: Server learning rate.
      regularizer: Optional regularizer that only depends on params.

    Returns:
      FederatedAlgorithm
    """
    grad_fn = models.grad(per_example_loss, regularizer)
    grads_for_each_client = create_grads_for_each_client(grad_fn)
    train_for_each_client = create_train_for_each_client(
        grad_fn, base_optimizer)

    def init(params: Params) -> ServerState:
        opt_state = base_optimizer.init(params)
        return ServerState(params, opt_state)

    def apply(
        server_state: ServerState,
        clients: Sequence[Tuple[federated_data.ClientId,
                                client_datasets.ClientDataset, PRNGKey]]
    ) -> Tuple[ServerState, Mapping[federated_data.ClientId, Any]]:
        # Compute full-batch gradient at server params on train data.
        grads_batch_clients = [(cid, cds.padded_batch(grads_batch_hparams),
                                crng) for cid, cds, crng in clients]
        grads_sum_total, num_sum_total = tree_util.tree_sum(
            (co for _, co in grads_for_each_client(server_state.params,
                                                   grads_batch_clients)))
        server_grads = tree_util.tree_inverse_weight(grads_sum_total,
                                                     num_sum_total)
        # Control variate corrected training across clients.
        client_diagnostics = {}
        client_num_examples = {cid: len(cds) for cid, cds, _ in clients}
        batch_clients = [(cid, cds.shuffle_repeat_batch(client_batch_hparams),
                          crng) for cid, cds, crng in clients]
        shared_input = {
            'params': server_state.params,
            'opt_state': server_state.opt_state,
            'control_variate': server_grads
        }
        # Running weighted mean of client updates.
        delta_params_sum = tree_util.tree_zeros_like(server_state.params)
        num_examples_sum = 0.
        for client_id, delta_params in train_for_each_client(
                shared_input, batch_clients):
            num_examples = client_num_examples[client_id]
            delta_params_sum = tree_util.tree_add(
                delta_params_sum,
                tree_util.tree_weight(delta_params, num_examples))
            num_examples_sum += num_examples
            client_diagnostics[client_id] = {
                'delta_l2_norm': tree_util.tree_l2_norm(delta_params)
            }
        mean_delta_params = tree_util.tree_inverse_weight(
            delta_params_sum, num_examples_sum)

        server_state = server_update(server_state, server_grads,
                                     mean_delta_params)
        return server_state, client_diagnostics

    def server_update(server_state, server_grads, mean_delta_params):
        # Update server params with the weighted average of client updates,
        # scaled by server_learning_rate.
        params = jax.tree_util.tree_map(
            lambda p, q: p - server_learning_rate * q, server_state.params,
            mean_delta_params)
        opt_state, _ = base_optimizer.apply(server_grads,
                                            server_state.opt_state,
                                            server_state.params)
        return ServerState(params, opt_state)

    return federated_algorithm.FederatedAlgorithm(init, apply)
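
A hypothetical wiring sketch for mime. Here per_example_loss, init_params, and clients are placeholders for user-supplied values, and the hyperparameters are illustrative, not tuned:

# Placeholders: per_example_loss, init_params, and clients must be supplied.
algorithm = mime(
    per_example_loss,
    base_optimizer=optimizers.sgd(learning_rate=0.1),
    client_batch_hparams=client_datasets.ShuffleRepeatBatchHParams(
        batch_size=8),
    grads_batch_hparams=client_datasets.PaddedBatchHParams(batch_size=8),
    server_learning_rate=1.0)
server_state = algorithm.init(init_params)
server_state, client_diagnostics = algorithm.apply(server_state, clients)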
Example #3
def federated_averaging(
    grad_fn: Callable[[Params, BatchExample, PRNGKey], Grads],
    client_optimizer: optimizers.Optimizer,
    server_optimizer: optimizers.Optimizer,
    client_batch_hparams: client_datasets.ShuffleRepeatBatchHParams
) -> federated_algorithm.FederatedAlgorithm:
    """Builds federated averaging.

    Args:
      grad_fn: A function from (params, batch_example, rng) to gradients. This
        can be created with :func:`fedjax.core.model.model_grad`.
      client_optimizer: Optimizer for local client training.
      server_optimizer: Optimizer for server update.
      client_batch_hparams: Hyperparameters for batching client datasets for
        training.

    Returns:
      FederatedAlgorithm
    """
    train_for_each_client = create_train_for_each_client(
        grad_fn, client_optimizer)

    def init(params: Params) -> ServerState:
        opt_state = server_optimizer.init(params)
        return ServerState(params, opt_state)

    def apply(
        server_state: ServerState,
        clients: Sequence[Tuple[federated_data.ClientId,
                                client_datasets.ClientDataset, PRNGKey]]
    ) -> Tuple[ServerState, Mapping[federated_data.ClientId, Any]]:
        client_num_examples = {cid: len(cds) for cid, cds, _ in clients}
        batch_clients = [(cid, cds.shuffle_repeat_batch(client_batch_hparams),
                          crng) for cid, cds, crng in clients]
        client_diagnostics = {}
        # Running weighted mean of client updates. We do this iteratively to
        # avoid loading all the client outputs into memory at once, since they
        # can be prohibitively large depending on the size of the model
        # parameters.
        delta_params_sum = tree_util.tree_zeros_like(server_state.params)
        num_examples_sum = 0.
        for client_id, delta_params in train_for_each_client(
                server_state.params, batch_clients):
            num_examples = client_num_examples[client_id]
            delta_params_sum = tree_util.tree_add(
                delta_params_sum,
                tree_util.tree_weight(delta_params, num_examples))
            num_examples_sum += num_examples
            # We record the l2 norm of client updates as an example, but it is not
            # required for the algorithm.
            client_diagnostics[client_id] = {
                'delta_l2_norm': tree_util.tree_l2_norm(delta_params)
            }
        mean_delta_params = tree_util.tree_inverse_weight(
            delta_params_sum, num_examples_sum)
        server_state = server_update(server_state, mean_delta_params)
        return server_state, client_diagnostics

    def server_update(server_state, mean_delta_params):
        opt_state, params = server_optimizer.apply(mean_delta_params,
                                                   server_state.opt_state,
                                                   server_state.params)
        return ServerState(params, opt_state)

    return federated_algorithm.FederatedAlgorithm(init, apply)
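
A hypothetical wiring sketch for federated averaging. Per the docstring, grad_fn can be created with fedjax's model gradient helper; grad_fn, init_params, clients, and num_rounds are placeholders and the hyperparameters are illustrative:

# Placeholders: grad_fn, init_params, clients, and num_rounds.
algorithm = federated_averaging(
    grad_fn,
    client_optimizer=optimizers.sgd(learning_rate=0.1),
    server_optimizer=optimizers.adam(learning_rate=1e-3),
    client_batch_hparams=client_datasets.ShuffleRepeatBatchHParams(
        batch_size=20))
server_state = algorithm.init(init_params)
for _ in range(num_rounds):
  # Each round would typically train a freshly sampled cohort of clients.
  server_state, client_diagnostics = algorithm.apply(server_state, clients)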
Example #4
def hyp_cluster(
    per_example_loss: Callable[[Params, BatchExample, PRNGKey], jnp.ndarray],
    client_optimizer: optimizers.Optimizer,
    server_optimizer: optimizers.Optimizer,
    maximization_batch_hparams: client_datasets.PaddedBatchHParams,
    expectation_batch_hparams: client_datasets.ShuffleRepeatBatchHParams,
    regularizer: Optional[Callable[[Params], jnp.ndarray]] = None
) -> federated_algorithm.FederatedAlgorithm:
  """Federated hypothesis-based clustering algorithm.

  Args:
    per_example_loss: A function from (params, batch, rng) to a vector of per
      example loss values.
    client_optimizer: Client side optimizer.
    server_optimizer: Server side optimizer.
    maximization_batch_hparams: Batching hyperparameters for the maximization
      step.
    expectation_batch_hparams: Batching hyperparameters for the expectation
      step.
    regularizer: Optional regularizer.

  Returns:
    A FederatedAlgorithm. A notable difference from other common
    FederatedAlgorithms such as fed_avg is that `init()` takes a list of
    `cluster_params`, which can be obtained using `random_init()`,
    `ModelKMeansInitializer`, or `kmeans_init()` in this module.
  """

  def init(cluster_params: List[Params]) -> ServerState:
    return ServerState(
        cluster_params,
        [server_optimizer.init(params) for params in cluster_params])

  # Creating these objects outside apply() can speed up repeated apply() calls.
  evaluator = models.AverageLossEvaluator(per_example_loss, regularizer)
  trainer = ClientDeltaTrainer(
      models.grad(per_example_loss, regularizer), client_optimizer)

  def apply(
      server_state: ServerState,
      clients: Sequence[Tuple[federated_data.ClientId,
                              client_datasets.ClientDataset, PRNGKey]]
  ) -> Tuple[ServerState, Mapping[federated_data.ClientId, Any]]:
    # Split RNGs for the 2 steps below.
    client_rngs = [jax.random.split(rng) for _, _, rng in clients]

    client_cluster_ids = maximization_step(
        evaluator=evaluator,
        cluster_params=server_state.cluster_params,
        clients=[(client_id, dataset, rng[0])
                 for (client_id, dataset, _), rng in zip(clients, client_rngs)],
        batch_hparams=maximization_batch_hparams)

    cluster_delta_params = expectation_step(
        trainer=trainer,
        cluster_params=server_state.cluster_params,
        client_cluster_ids=client_cluster_ids,
        clients=[(client_id, dataset, rng[1])
                 for (client_id, dataset, _), rng in zip(clients, client_rngs)],
        batch_hparams=expectation_batch_hparams)

    # Apply delta params for each cluster.
    cluster_params = []
    opt_states = []
    for delta_params, opt_state, params in zip(cluster_delta_params,
                                               server_state.opt_states,
                                               server_state.cluster_params):
      if delta_params is None:
        # No examples were observed for this cluster.
        next_opt_state, next_params = opt_state, params
      else:
        next_opt_state, next_params = server_optimizer.apply(
            delta_params, opt_state, params)
      cluster_params.append(next_params)
      opt_states.append(next_opt_state)

    # TODO(wuke): Other client diagnostics.
    client_diagnostics = {}
    for client_id, cluster_id in client_cluster_ids.items():
      client_diagnostics[client_id] = {'cluster_id': cluster_id}
    return ServerState(cluster_params, opt_states), client_diagnostics

  return federated_algorithm.FederatedAlgorithm(init, apply)
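
A hypothetical wiring sketch for hyp_cluster. As the docstring notes, init() takes a list of cluster params (one Params pytree per cluster), e.g. from random_init() in this module; params_a, params_b, per_example_loss, and clients below are placeholders:

# Placeholders: per_example_loss, params_a, params_b, and clients.
algorithm = hyp_cluster(
    per_example_loss,
    client_optimizer=optimizers.sgd(learning_rate=0.1),
    server_optimizer=optimizers.sgd(learning_rate=1.0),
    maximization_batch_hparams=client_datasets.PaddedBatchHParams(
        batch_size=8),
    expectation_batch_hparams=client_datasets.ShuffleRepeatBatchHParams(
        batch_size=8))
server_state = algorithm.init([params_a, params_b])
server_state, client_diagnostics = algorithm.apply(server_state, clients)
# client_diagnostics[client_id]['cluster_id'] reports each client's cluster.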
Example #5
def agnostic_federated_averaging(
    per_example_loss: Callable[[Params, BatchExample, PRNGKey], jnp.ndarray],
    client_optimizer: optimizers.Optimizer,
    server_optimizer: optimizers.Optimizer,
    client_batch_hparams: client_datasets.ShuffleRepeatBatchHParams,
    domain_batch_hparams: client_datasets.PaddedBatchHParams,
    init_domain_weights: Sequence[float],
    domain_learning_rate: float,
    domain_algorithm: str = 'eg',
    domain_window_size: int = 1,
    init_domain_window: Optional[Sequence[float]] = None,
    regularizer: Optional[Callable[[Params], jnp.ndarray]] = None
) -> federated_algorithm.FederatedAlgorithm:
    """Builds agnostic federated averaging.

    Agnostic federated averaging requires input
    :class:`fedjax.core.client_datasets.ClientDataset` examples to contain
    a feature named "domain_id", which stores the integer domain id in
    [0, num_domains). For example, for Stack Overflow, each example post can be
    either a question or an answer, so there are two possible domain ids
    (question = 0; answer = 1).

    Args:
      per_example_loss: A function from (params, batch_example, rng) to a vector
        of loss values for each example in the batch. This is used in both the
        domain metrics computation and gradient descent training.
      client_optimizer: Optimizer for local client training.
      server_optimizer: Optimizer for server update.
      client_batch_hparams: Hyperparameters for batching client datasets for
        training.
      domain_batch_hparams: Hyperparameters for batching client datasets for
        domain metrics calculation.
      init_domain_weights: Initial weights per domain that must sum to 1.
      domain_learning_rate: Learning rate for the domain weight update.
      domain_algorithm: Algorithm used to update domain weights each round. One
        of 'eg', 'none'.
      domain_window_size: Size of the sliding window tracking the number of
        examples per domain over multiple rounds.
      init_domain_window: Initial values for the domain window. Defaults to
        ones.
      regularizer: Optional regularizer that only depends on params.

    Returns:
      FederatedAlgorithm.

    Raises:
      ValueError: If ``init_domain_weights`` does not sum to 1 or if
        ``init_domain_weights`` and ``init_domain_window`` have unequal lengths.
    """
    if abs(sum(init_domain_weights) - 1) > 1e-6:
        raise ValueError('init_domain_weights must sum to approximately 1.')

    if init_domain_window is None:
        init_domain_window = jnp.ones_like(init_domain_weights)

    if len(init_domain_weights) != len(init_domain_window):
        raise ValueError(
            'init_domain_weights and init_domain_window must have equal '
            f'lengths. {len(init_domain_weights)} != {len(init_domain_window)}')

    num_domains = len(init_domain_weights)
    domain_metrics_for_each_client = create_domain_metrics_for_each_client(
        per_example_loss, num_domains)
    train_for_each_client = create_train_for_each_client(
        per_example_loss, client_optimizer, num_domains, regularizer)

    def init(params: Params) -> ServerState:
        opt_state = server_optimizer.init(params)
        domain_weights = jnp.array(init_domain_weights)
        domain_window = [jnp.array(init_domain_window)] * domain_window_size
        return ServerState(params, opt_state, domain_weights, domain_window)

    def apply(
        server_state: ServerState,
        clients: Sequence[Tuple[federated_data.ClientId,
                                client_datasets.ClientDataset, PRNGKey]]
    ) -> Tuple[ServerState, Mapping[federated_data.ClientId, Any]]:
        # α
        alpha = server_state.domain_weights / jnp.mean(
            jnp.asarray(server_state.domain_window), axis=0)
        # First pass to calculate initial domain loss, domain num, and scaling
        # weight β for each client. This doesn't involve any aggregation at the
        # server, so this step and training can be a single round of communication.
        domain_batch_clients = [(cid, cds.padded_batch(domain_batch_hparams),
                                 crng) for cid, cds, crng in clients]
        shared_input = {'params': server_state.params, 'alpha': alpha}
        # L^k, N^k, β^k
        client_domain_metrics = dict(
            domain_metrics_for_each_client(shared_input, domain_batch_clients))
        # Train for each client using scaling weights α and β.
        batch_clients = []
        for cid, cds, crng in clients:
            client_input = {
                'rng': crng,
                'beta': client_domain_metrics[cid]['beta']
            }
            batch_clients.append(
                (cid, cds.shuffle_repeat_batch(client_batch_hparams),
                 client_input))

        client_diagnostics = {}
        # Mean delta params across clients.
        delta_params_sum = tree_util.tree_zeros_like(server_state.params)
        weight_sum = 0.
        # w^k
        for cid, delta_params in train_for_each_client(shared_input,
                                                       batch_clients):
            weight = client_domain_metrics[cid]['beta']
            delta_params_sum = tree_util.tree_add(
                delta_params_sum, tree_util.tree_weight(delta_params, weight))
            weight_sum += weight
            client_diagnostics[cid] = {
                'delta_l2_norm': tree_util.tree_l2_norm(delta_params)
            }
        mean_delta_params = tree_util.tree_inverse_weight(
            delta_params_sum, weight_sum)
        # Sum domain metrics across clients.
        sum_domain_loss = tree_util.tree_sum(
            d['domain_loss'] for d in client_domain_metrics.values())
        sum_domain_num = tree_util.tree_sum(
            d['domain_num'] for d in client_domain_metrics.values())
        server_state = server_update(server_state, mean_delta_params,
                                     sum_domain_loss, sum_domain_num)
        return server_state, client_diagnostics

    def server_update(server_state, mean_delta_params, sum_domain_loss,
                      sum_domain_num):
        opt_state, params = server_optimizer.apply(mean_delta_params,
                                                   server_state.opt_state,
                                                   server_state.params)
        mean_domain_loss = util.safe_div(sum_domain_loss, sum_domain_num)
        domain_weights = update_domain_weights(server_state.domain_weights,
                                               mean_domain_loss,
                                               domain_learning_rate,
                                               domain_algorithm)
        domain_window = server_state.domain_window[1:] + [sum_domain_num]
        return ServerState(params, opt_state, domain_weights, domain_window)

    return federated_algorithm.FederatedAlgorithm(init, apply)
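
A hypothetical wiring sketch for agnostic federated averaging with two domains (question = 0, answer = 1, as in the Stack Overflow example above). per_example_loss, init_params, and clients are placeholders and the values are illustrative:

# Placeholders: per_example_loss, init_params, and clients. Every client
# example must contain a 'domain_id' feature in [0, 2).
algorithm = agnostic_federated_averaging(
    per_example_loss,
    client_optimizer=optimizers.sgd(learning_rate=0.1),
    server_optimizer=optimizers.sgd(learning_rate=1.0),
    client_batch_hparams=client_datasets.ShuffleRepeatBatchHParams(
        batch_size=8),
    domain_batch_hparams=client_datasets.PaddedBatchHParams(batch_size=8),
    init_domain_weights=[0.5, 0.5],
    domain_learning_rate=0.01,
    domain_algorithm='eg')
server_state = algorithm.init(init_params)
server_state, client_diagnostics = algorithm.apply(server_state, clients)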
Example #6
def fed_prox(per_example_loss: Callable[[Params, BatchExample, PRNGKey],
                                        jnp.ndarray],
             client_optimizer: optimizers.Optimizer,
             server_optimizer: optimizers.Optimizer,
             client_batch_hparams: client_datasets.ShuffleRepeatBatchHParams,
             proximal_weight: float) -> federated_algorithm.FederatedAlgorithm:
  """Builds FedProx.

  Args:
    per_example_loss: A function from (params, batch, rng) to a vector of per
      example loss values. This will be combined with a proximal term based on
      server params weighted by proximal_weight.
    client_optimizer: Optimizer for local client training.
    server_optimizer: Optimizer for server update.
    client_batch_hparams: Hyperparameters for batching client dataset for train.
    proximal_weight: Weight for proximal term. 0 weight is FedAvg.

  Returns:
    FederatedAlgorithm
  """

  def fed_prox_loss(params, server_params, batch, rng):
    example_loss = per_example_loss(params, batch, rng)
    proximal_loss = 0.5 * proximal_weight * tree_util.tree_l2_squared(
        jax.tree_util.tree_map(lambda a, b: a - b, server_params, params))
    return jnp.mean(example_loss + proximal_loss)

  grad_fn = jax.grad(fed_prox_loss)
  train_for_each_client = create_train_for_each_client(grad_fn,
                                                       client_optimizer)

  def init(params: Params) -> ServerState:
    opt_state = server_optimizer.init(params)
    return ServerState(params, opt_state)

  def apply(
      server_state: ServerState,
      clients: Sequence[Tuple[federated_data.ClientId,
                              client_datasets.ClientDataset, PRNGKey]]
  ) -> Tuple[ServerState, Mapping[federated_data.ClientId, Any]]:
    client_num_examples = {cid: len(cds) for cid, cds, _ in clients}
    batch_clients = [(cid, cds.shuffle_repeat_batch(client_batch_hparams), crng)
                     for cid, cds, crng in clients]
    client_diagnostics = {}
    # Running weighted mean of client updates. We do this iteratively to avoid
    # loading all the client outputs into memory at once, since they can be
    # prohibitively large depending on the size of the model parameters.
    delta_params_sum = tree_util.tree_zeros_like(server_state.params)
    num_examples_sum = 0.
    for client_id, delta_params in train_for_each_client(
        server_state.params, batch_clients):
      num_examples = client_num_examples[client_id]
      delta_params_sum = tree_util.tree_add(
          delta_params_sum, tree_util.tree_weight(delta_params, num_examples))
      num_examples_sum += num_examples
      # We record the l2 norm of client updates as an example, but it is not
      # required for the algorithm.
      client_diagnostics[client_id] = {
          'delta_l2_norm': tree_util.tree_l2_norm(delta_params)
      }
    mean_delta_params = tree_util.tree_inverse_weight(delta_params_sum,
                                                      num_examples_sum)
    server_state = server_update(server_state, mean_delta_params)
    return server_state, client_diagnostics

  def server_update(server_state, mean_delta_params):
    opt_state, params = server_optimizer.apply(mean_delta_params,
                                               server_state.opt_state,
                                               server_state.params)
    return ServerState(params, opt_state)

  return federated_algorithm.FederatedAlgorithm(init, apply)
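
A hypothetical wiring sketch for fed_prox, with a toy linear-regression loss standing in for a real model's per-example loss; clients is a placeholder for a sequence of (client_id, ClientDataset, PRNGKey) triples and the values are illustrative:

def toy_per_example_loss(params, batch, rng):
  """Per-example squared error for a toy linear model (illustrative only)."""
  del rng  # Unused by this toy loss.
  return jnp.square(batch['x'] @ params['w'] - batch['y'])

# Placeholder: clients.
algorithm = fed_prox(
    toy_per_example_loss,
    client_optimizer=optimizers.sgd(learning_rate=0.1),
    server_optimizer=optimizers.sgd(learning_rate=1.0),
    client_batch_hparams=client_datasets.ShuffleRepeatBatchHParams(
        batch_size=8),
    proximal_weight=0.01)
server_state = algorithm.init({'w': jnp.zeros(2)})
server_state, client_diagnostics = algorithm.apply(server_state, clients)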