Example #1
    def test_expectation_step(self):
        def per_example_loss(params, batch, rng):
            self.assertIsNotNone(rng)
            return jnp.square(params - batch['x'])

        trainer = hyp_cluster.ClientDeltaTrainer(models.grad(per_example_loss),
                                                 optimizers.sgd(0.5))
        batch_hparams = client_datasets.ShuffleRepeatBatchHParams(batch_size=1,
                                                                  num_epochs=5)

        cluster_params = [jnp.array(1.), jnp.array(-1.), jnp.array(3.14)]
        client_cluster_ids = {b'0': 0, b'1': 0, b'2': 1, b'3': 1, b'4': 0}
        # RNGs are passed through but do not affect the loss values.
        clients = [
            (b'0', client_datasets.ClientDataset({'x': np.array([1.1])}),
             jax.random.PRNGKey(0)),
            (b'1', client_datasets.ClientDataset({'x': np.array([0.9, 0.9])}),
             jax.random.PRNGKey(1)),
            (b'2', client_datasets.ClientDataset({'x': np.array([-1.1])}),
             jax.random.PRNGKey(2)),
            (b'3',
             client_datasets.ClientDataset({'x': np.array([-0.9, -0.9,
                                                           -0.9])}),
             jax.random.PRNGKey(3)),
            (b'4', client_datasets.ClientDataset({'x': np.array([-0.1])}),
             jax.random.PRNGKey(4)),
        ]
        cluster_delta_params = hyp_cluster.expectation_step(
            trainer=trainer,
            cluster_params=cluster_params,
            client_cluster_ids=client_cluster_ids,
            clients=clients,
            batch_hparams=batch_hparams)
        self.assertIsInstance(cluster_delta_params, list)
        self.assertLen(cluster_delta_params, 3)
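        # With SGD(0.5) on the squared loss, each step moves params exactly to
        # the current example's x, so a client's delta is its cluster's initial
        # params minus the client's x value. expectation_step then averages
        # these deltas within each cluster, weighted by client dataset size:
        # cluster 0 gets b'0', b'1', b'4' (4 examples) and cluster 1 gets
        # b'2', b'3' (4 examples). Cluster 2 has no assigned clients.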
        npt.assert_allclose(cluster_delta_params[0],
                            (-0.1 + 0.1 * 2 + 1.1) / 4)
        npt.assert_allclose(cluster_delta_params[1], (0.1 - 0.1 * 3) / 4,
                            rtol=1e-6)
        self.assertIsNone(cluster_delta_params[2])
Example #2
def mime(
    per_example_loss: Callable[[Params, BatchExample, PRNGKey], jnp.ndarray],
    base_optimizer: optimizers.Optimizer,
    client_batch_hparams: client_datasets.ShuffleRepeatBatchHParams,
    grads_batch_hparams: client_datasets.PaddedBatchHParams,
    server_learning_rate: float,
    regularizer: Optional[Callable[[Params], jnp.ndarray]] = None
) -> federated_algorithm.FederatedAlgorithm:
    """Builds mime.

    Args:
      per_example_loss: A function from (params, batch_example, rng) to a vector
        of loss values for each example in the batch. This is used in both the
        server gradient computation and gradient descent training.
      base_optimizer: Base optimizer to mimic.
      client_batch_hparams: Hyperparameters for batching client dataset for
        training.
      grads_batch_hparams: Hyperparameters for batching client dataset for
        server gradient computation.
      server_learning_rate: Server learning rate.
      regularizer: Optional regularizer that only depends on params.

    Returns:
      FederatedAlgorithm
    """
    grad_fn = models.grad(per_example_loss, regularizer)
    grads_for_each_client = create_grads_for_each_client(grad_fn)
    train_for_each_client = create_train_for_each_client(
        grad_fn, base_optimizer)

    def init(params: Params) -> ServerState:
        opt_state = base_optimizer.init(params)
        return ServerState(params, opt_state)

    def apply(
        server_state: ServerState,
        clients: Sequence[Tuple[federated_data.ClientId,
                                client_datasets.ClientDataset, PRNGKey]]
    ) -> Tuple[ServerState, Mapping[federated_data.ClientId, Any]]:
        # Compute full-batch gradient at server params on train data.
        grads_batch_clients = [(cid, cds.padded_batch(grads_batch_hparams),
                                crng) for cid, cds, crng in clients]
        grads_sum_total, num_sum_total = tree_util.tree_sum(
            (co for _, co in grads_for_each_client(server_state.params,
                                                   grads_batch_clients)))
        server_grads = tree_util.tree_inverse_weight(grads_sum_total,
                                                     num_sum_total)
        # Control variate corrected training across clients.
        client_diagnostics = {}
        client_num_examples = {cid: len(cds) for cid, cds, _ in clients}
        batch_clients = [(cid, cds.shuffle_repeat_batch(client_batch_hparams),
                          crng) for cid, cds, crng in clients]
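        # Inputs shared across clients: each client trains from the current
        # server params, mimicking the server optimizer state and using the
        # full-batch server gradient as the control variate for correction.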
        shared_input = {
            'params': server_state.params,
            'opt_state': server_state.opt_state,
            'control_variate': server_grads
        }
        # Running weighted mean of client updates.
        delta_params_sum = tree_util.tree_zeros_like(server_state.params)
        num_examples_sum = 0.
        for client_id, delta_params in train_for_each_client(
                shared_input, batch_clients):
            num_examples = client_num_examples[client_id]
            delta_params_sum = tree_util.tree_add(
                delta_params_sum,
                tree_util.tree_weight(delta_params, num_examples))
            num_examples_sum += num_examples
            client_diagnostics[client_id] = {
                'delta_l2_norm': tree_util.tree_l2_norm(delta_params)
            }
        mean_delta_params = tree_util.tree_inverse_weight(
            delta_params_sum, num_examples_sum)

        server_state = server_update(server_state, server_grads,
                                     mean_delta_params)
        return server_state, client_diagnostics

    def server_update(server_state, server_grads, mean_delta_params):
        # Server params uses weighted average of client updates, scaled by the
        # server_learning_rate.
        params = jax.tree_util.tree_map(
            lambda p, q: p - server_learning_rate * q, server_state.params,
            mean_delta_params)
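        # Only the optimizer state advances with the full-batch server
        # gradients; the params returned by base_optimizer.apply are discarded.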
        opt_state, _ = base_optimizer.apply(server_grads,
                                            server_state.opt_state,
                                            server_state.params)
        return ServerState(params, opt_state)

    return federated_algorithm.FederatedAlgorithm(init, apply)
Example #3
from absl.testing import absltest

from fedjax.algorithms import mime
from fedjax.core import client_datasets
from fedjax.core import models
from fedjax.core import optimizers

import jax
import jax.numpy as jnp
import numpy.testing as npt


def per_example_loss(params, batch, rng):
  del rng
  return batch['x'] * params['w']


grad_fn = models.grad(per_example_loss)


class MimeTest(absltest.TestCase):

  def test_mime(self):
    base_optimizer = optimizers.sgd(learning_rate=1.0)
    train_batch_hparams = client_datasets.ShuffleRepeatBatchHParams(
        batch_size=2, num_epochs=1, seed=0)
    grad_batch_hparams = client_datasets.PaddedBatchHParams(batch_size=2)
    server_learning_rate = 0.2
    algorithm = mime.mime(per_example_loss, base_optimizer, train_batch_hparams,
                          grad_batch_hparams, server_learning_rate)

    with self.subTest('init'):
      state = algorithm.init({'w': jnp.array(4.)})
Example #4
def hyp_cluster(
    per_example_loss: Callable[[Params, BatchExample, PRNGKey], jnp.ndarray],
    client_optimizer: optimizers.Optimizer,
    server_optimizer: optimizers.Optimizer,
    maximization_batch_hparams: client_datasets.PaddedBatchHParams,
    expectation_batch_hparams: client_datasets.ShuffleRepeatBatchHParams,
    regularizer: Optional[Callable[[Params], jnp.ndarray]] = None
) -> federated_algorithm.FederatedAlgorithm:
  """Federated hypothesis-based clustering algorithm.

  Args:
    per_example_loss: A function from (params, batch, rng) to a vector of per
      example loss values.
    client_optimizer: Client side optimizer.
    server_optimizer: Server side optimizer.
    maximization_batch_hparams: Batching hyperparameters for the maximization
      step.
    expectation_batch_hparams: Batching hyperparameters for the expectation
      step.
    regularizer: Optional regularizer.

  Returns:
    A FederatedAlgorithm. A notable difference from other common
    FederatedAlgorithms such as fed_avg is that `init()` takes a list of
    `cluster_params`, which can be obtained using `random_init()`,
    `ModelKMeansInitializer`, or `kmeans_init()` in this module.
  """

  def init(cluster_params: List[Params]) -> ServerState:
    return ServerState(
        cluster_params,
        [server_optimizer.init(params) for params in cluster_params])

  # Creating these objects outside apply() can speed up repeated apply() calls.
  evaluator = models.AverageLossEvaluator(per_example_loss, regularizer)
  trainer = ClientDeltaTrainer(
      models.grad(per_example_loss, regularizer), client_optimizer)

  def apply(
      server_state: ServerState,
      clients: Sequence[Tuple[federated_data.ClientId,
                              client_datasets.ClientDataset, PRNGKey]]
  ) -> Tuple[ServerState, Mapping[federated_data.ClientId, Any]]:
    # Split RNGs for the 2 steps below.
    client_rngs = [jax.random.split(rng) for _, _, rng in clients]

    client_cluster_ids = maximization_step(
        evaluator=evaluator,
        cluster_params=server_state.cluster_params,
        clients=[(client_id, dataset, rng[0])
                 for (client_id, dataset, _), rng in zip(clients, client_rngs)],
        batch_hparams=maximization_batch_hparams)

    cluster_delta_params = expectation_step(
        trainer=trainer,
        cluster_params=server_state.cluster_params,
        client_cluster_ids=client_cluster_ids,
        clients=[(client_id, dataset, rng[1])
                 for (client_id, dataset, _), rng in zip(clients, client_rngs)],
        batch_hparams=expectation_batch_hparams)

    # Apply delta params for each cluster.
    cluster_params = []
    opt_states = []
    for delta_params, opt_state, params in zip(cluster_delta_params,
                                               server_state.opt_states,
                                               server_state.cluster_params):
      if delta_params is None:
        # No examples were observed for this cluster.
        next_opt_state, next_params = opt_state, params
      else:
        next_opt_state, next_params = server_optimizer.apply(
            delta_params, opt_state, params)
      cluster_params.append(next_params)
      opt_states.append(next_opt_state)

    # TODO(wuke): Other client diagnostics.
    client_diagnostics = {}
    for client_id, cluster_id in client_cluster_ids.items():
      client_diagnostics[client_id] = {'cluster_id': cluster_id}
    return ServerState(cluster_params, opt_states), client_diagnostics

  return federated_algorithm.FederatedAlgorithm(init, apply)
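
A minimal usage sketch for the builder above (not from the fedjax sources; the toy scalar model, client data, and hyperparameters are illustrative assumptions). It builds the algorithm with SGD on both client and server, initializes two hand-picked cluster params, and runs a single round on two toy clients. In practice the initial cluster params would come from random_init(), ModelKMeansInitializer, or kmeans_init(), as the docstring notes.

import jax
import jax.numpy as jnp
import numpy as np

from fedjax.algorithms import hyp_cluster
from fedjax.core import client_datasets
from fedjax.core import optimizers


# Sketch only: toy scalar "model" and data; hyperparameters are illustrative.
def per_example_loss(params, batch, rng):
  del rng
  return jnp.square(params - batch['x'])


algorithm = hyp_cluster.hyp_cluster(
    per_example_loss=per_example_loss,
    client_optimizer=optimizers.sgd(learning_rate=0.5),
    server_optimizer=optimizers.sgd(learning_rate=1.0),
    maximization_batch_hparams=client_datasets.PaddedBatchHParams(batch_size=2),
    expectation_batch_hparams=client_datasets.ShuffleRepeatBatchHParams(
        batch_size=1, num_epochs=5))

# init() takes a list of cluster params rather than a single Params pytree.
state = algorithm.init([jnp.array(1.), jnp.array(-1.)])
clients = [
    (b'0', client_datasets.ClientDataset({'x': np.array([1.1, 0.9])}),
     jax.random.PRNGKey(0)),
    (b'1', client_datasets.ClientDataset({'x': np.array([-1.1])}),
     jax.random.PRNGKey(1)),
]
state, client_diagnostics = algorithm.apply(state, clients)
# client_diagnostics[client_id]['cluster_id'] reports each client's assigned
# cluster for this round.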