Example #1
  def __init__(self,
               model: models.Model,
               client_optimizer: optimizers.Optimizer,
               regularizer: Optional[Callable[[Params], jnp.ndarray]] = None):
    self._model = model
    self._trainer = ClientParamsTrainer(
        models.model_grad(model, regularizer), client_optimizer)
    self._evaluator = models.AverageLossEvaluator(
        models.model_per_example_loss(model), regularizer)
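
This constructor pairs a ClientParamsTrainer (gradient-based training of full client parameters via models.model_grad) with an AverageLossEvaluator for scoring candidate parameters. A minimal construction sketch follows; it assumes the enclosing class is ModelKMeansInitializer from this module and that a models.Model instance and a fedjax-style SGD optimizer factory are available (all placeholder names are assumptions, not shown above):

# Sketch only. `my_model` is a placeholder for any models.Model, and the
# enclosing class is assumed to be ModelKMeansInitializer from this module.
from fedjax import models, optimizers

initializer = ModelKMeansInitializer(
    model=my_model,
    client_optimizer=optimizers.sgd(learning_rate=0.1),
    regularizer=None,  # optionally, a penalty on params as in the examples below
)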
Example #2
  def __init__(self,
               model: models.Model,
               regularizer: Optional[Callable[[Params], jnp.ndarray]] = None):
    """Initializes some reusable components.

    Because we need to make multiple passes over each client during evaluation,
    the number of clients that can be evaluated at once is limited. Therefore
    multiple calls to :meth:`~HypClusterEvaluator.evaluate_clients` are needed
    to evaluate a large federated dataset. We factor out some reusable
    components so that the same computation can be jit compiled.

    Args:
      model: Model being evaluated.
      regularizer: Optional regularizer.
    """
    self._maximization_step_evaluator = models.AverageLossEvaluator(
        models.model_per_example_loss(model), regularizer)
    self._model_evaluator = models.ModelEvaluator(model)
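
As the docstring notes, only a limited number of clients can be evaluated per call, so evaluating a large federated dataset means looping over it in chunks. A rough sketch of that outer loop, with the chunking helper and the evaluate_clients argument list deliberately left abstract since neither is shown above:

# Sketch only: `chunked` is a hypothetical helper and the evaluate_clients
# arguments are elided (...); the chunk size of 64 is arbitrary.
evaluator = HypClusterEvaluator(model)
metrics = {}
for client_chunk in chunked(all_clients, 64):
  for client_id, client_metrics in evaluator.evaluate_clients(...):
    metrics[client_id] = client_metrics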
Example #3
    def test_maximization_step(self):
        # L1 distance from centers.
        def per_example_loss(params, batch, rng):
            # Randomly flip the center to test rng behavior.
            sign = jax.random.bernoulli(rng) * 2 - 1
            return jnp.abs(params * sign - batch['x'])

        def regularizer(params):
            return 0.01 * jnp.sum(jnp.abs(params))

        evaluator = models.AverageLossEvaluator(per_example_loss, regularizer)
        cluster_params = [jnp.array(0.), jnp.array(-1.), jnp.array(2.)]
        # Batch size is chosen so that we run 1 or 2 batches.
        batch_hparams = client_datasets.PaddedBatchHParams(batch_size=2)
        # Special seeds:
        # - No flip in first 2 steps for all 3 clusters: 0;
        # - Flip all in first 2 steps for all 3 clusters: 16;
        # - No flip then flip all for all 3 clusters: 68;
        # - Flips only cluster 1 in first 2 steps: 106.
        clients = [
            # No flip in first 2 steps for all 3 clusters.
            (b'near0', client_datasets.ClientDataset({'x': np.array([0.1])}),
             jax.random.PRNGKey(0)),
            # Flip all in first 2 steps for all 3 clusters.
            (b'near-1',
             client_datasets.ClientDataset({'x': np.array([0.9, 1.1, 1.3])}),
             jax.random.PRNGKey(16)),
            # No flip then flip all for all 3 clusters.
            (b'near2',
             client_datasets.ClientDataset({'x': np.array([1.9, 2.1, -2.1])}),
             jax.random.PRNGKey(68)),
            # Flips only cluster 1 in first 2 steps.
            (b'near1',
             client_datasets.ClientDataset({'x': np.array([0.9, 1.1, 1.3])}),
             jax.random.PRNGKey(106)),
        ]

        cluster_losses = hyp_cluster._cluster_losses(
            evaluator=evaluator,
            cluster_params=cluster_params,
            clients=clients,
            batch_hparams=batch_hparams)
        self.assertCountEqual(cluster_losses,
                              [b'near0', b'near-1', b'near2', b'near1'])
        npt.assert_allclose(cluster_losses[b'near0'],
                            np.array([0.1, 1.1 + 0.01, 1.9 + 0.02]))
        npt.assert_allclose(cluster_losses[b'near-1'],
                            np.array([1.1, 0.5 / 3 + 0.01, 3.1 + 0.02]))
        npt.assert_allclose(cluster_losses[b'near2'],
                            np.array([6.1 / 3, 9.1 / 3 + 0.01, 0.1 + 0.02]),
                            rtol=1e-6)
        npt.assert_allclose(cluster_losses[b'near1'],
                            np.array([1.1, 0.5 / 3 + 0.01, 0.9 + 0.02]))

        client_cluster_ids = hyp_cluster.maximization_step(
            evaluator=evaluator,
            cluster_params=cluster_params,
            clients=clients,
            batch_hparams=batch_hparams)
        self.assertDictEqual(client_cluster_ids, {
            b'near0': 0,
            b'near-1': 1,
            b'near2': 2,
            b'near1': 1
        })
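
One of the expected values above can be reproduced by hand. With batch_size=2, client b'near2' is split into padded batches [1.9, 2.1] and [-2.1]; seed 68 leaves the first batch unflipped and flips the second, and for cluster 0 (params = 0) the flip cannot change the loss:

# Hand check of cluster_losses[b'near2'][0] asserted above.
loss_batch_1 = abs(0.0 - 1.9) + abs(0.0 - 2.1)  # first batch, no flip
loss_batch_2 = abs(-0.0 - (-2.1))               # second batch, flipped; params=0 so the sign is moot
average = (loss_batch_1 + loss_batch_2) / 3     # = 6.1 / 3; regularizer adds 0.01 * |0| = 0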
Example #4
def hyp_cluster(
    per_example_loss: Callable[[Params, BatchExample, PRNGKey], jnp.ndarray],
    client_optimizer: optimizers.Optimizer,
    server_optimizer: optimizers.Optimizer,
    maximization_batch_hparams: client_datasets.PaddedBatchHParams,
    expectation_batch_hparams: client_datasets.ShuffleRepeatBatchHParams,
    regularizer: Optional[Callable[[Params], jnp.ndarray]] = None
) -> federated_algorithm.FederatedAlgorithm:
  """Federated hypothesis-based clustering algorithm.

  Args:
    per_example_loss: A function from (params, batch, rng) to a vector of
      per-example loss values.
    client_optimizer: Client side optimizer.
    server_optimizer: Server side optimizer.
    maximization_batch_hparams: Batching hyperparameters for the maximization
      step.
    expectation_batch_hparams: Batching hyperparameters for the expectation
      step.
    regularizer: Optional regularizer.

  Returns:
    A FederatedAlgorithm. A notable difference from other common
    FederatedAlgorithms such as fed_avg is that `init()` takes a list of
    `cluster_params`, which can be obtained using `random_init()`,
    `ModelKMeansInitializer`, or `kmeans_init()` in this module.
  """

  def init(cluster_params: List[Params]) -> ServerState:
    return ServerState(
        cluster_params,
        [server_optimizer.init(params) for params in cluster_params])

  # Creating these objects outside apply() can speed up repeated apply() calls.
  evaluator = models.AverageLossEvaluator(per_example_loss, regularizer)
  trainer = ClientDeltaTrainer(
      models.grad(per_example_loss, regularizer), client_optimizer)

  def apply(
      server_state: ServerState,
      clients: Sequence[Tuple[federated_data.ClientId,
                              client_datasets.ClientDataset, PRNGKey]]
  ) -> Tuple[ServerState, Mapping[federated_data.ClientId, Any]]:
    # Split RNGs for the 2 steps below.
    client_rngs = [jax.random.split(rng) for _, _, rng in clients]

    client_cluster_ids = maximization_step(
        evaluator=evaluator,
        cluster_params=server_state.cluster_params,
        clients=[(client_id, dataset, rng[0])
                 for (client_id, dataset, _), rng in zip(clients, client_rngs)],
        batch_hparams=maximization_batch_hparams)

    cluster_delta_params = expectation_step(
        trainer=trainer,
        cluster_params=server_state.cluster_params,
        client_cluster_ids=client_cluster_ids,
        clients=[(client_id, dataset, rng[1])
                 for (client_id, dataset, _), rng in zip(clients, client_rngs)],
        batch_hparams=expectation_batch_hparams)

    # Apply delta params for each cluster.
    cluster_params = []
    opt_states = []
    for delta_params, opt_state, params in zip(cluster_delta_params,
                                               server_state.opt_states,
                                               server_state.cluster_params):
      if delta_params is None:
        # No examples were observed for this cluster.
        next_opt_state, next_params = opt_state, params
      else:
        next_opt_state, next_params = server_optimizer.apply(
            delta_params, opt_state, params)
      cluster_params.append(next_params)
      opt_states.append(next_opt_state)

    # TODO(wuke): Other client diagnostics.
    client_diagnostics = {}
    for client_id, cluster_id in client_cluster_ids.items():
      client_diagnostics[client_id] = {'cluster_id': cluster_id}
    return ServerState(cluster_params, opt_states), client_diagnostics

  return federated_algorithm.FederatedAlgorithm(init, apply)
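
For context, a minimal driving loop for the returned FederatedAlgorithm. The hparams constructors mirror those used elsewhere on this page; initial_cluster_params stands in for the output of random_init(), ModelKMeansInitializer, or kmeans_init() as mentioned in the docstring, and my_per_example_loss, sampled_clients, and num_rounds are placeholders (the sgd factory and the num_epochs field are assumed fedjax conventions):

# Sketch only; placeholder names are flagged in the lead-in above.
algorithm = hyp_cluster(
    per_example_loss=my_per_example_loss,
    client_optimizer=optimizers.sgd(learning_rate=0.1),
    server_optimizer=optimizers.sgd(learning_rate=1.0),
    maximization_batch_hparams=client_datasets.PaddedBatchHParams(batch_size=64),
    expectation_batch_hparams=client_datasets.ShuffleRepeatBatchHParams(
        batch_size=64, num_epochs=1),
)
server_state = algorithm.init(initial_cluster_params)  # note: a list of cluster params
for _ in range(num_rounds):
  # Each client triple is (client_id, ClientDataset, PRNGKey), as in apply().
  server_state, client_diagnostics = algorithm.apply(server_state, sampled_clients)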
Example #5
    def test_evaluate_per_client_params(self):
        def per_example_loss(params, batch, rng):
            return params + batch['x'] + jax.random.uniform(rng, [])

        def regularizer(params):
            return 0.5 * jnp.sum(jnp.square(params))

        rng_0 = jax.random.PRNGKey(0)
        rng_uniform_00 = jax.random.uniform(jax.random.split(rng_0)[1], [])
        rng_term_0 = rng_uniform_00
        rng_1 = jax.random.PRNGKey(1)
        rng_uniform_10 = jax.random.uniform(jax.random.split(rng_1)[1], [])
        rng_uniform_11 = jax.random.uniform(
            jax.random.split(jax.random.split(rng_1)[0])[1], [])
        rng_term_1 = (rng_uniform_10 * 2 + rng_uniform_11) / 3

        with self.subTest('no mask, no regularizer'):
            clients = [
                (b'0000', [{
                    'x': jnp.array([2, 3])
                }], rng_0, jnp.array(0)),
                (b'1001', [{
                    'x': jnp.array([3, 4])
                }, {
                    'x': jnp.array([5])
                }], rng_1, jnp.array(1)),
            ]
            average_loss = dict(
                models.AverageLossEvaluator(per_example_loss).
                evaluate_per_client_params(clients=clients))
            npt.assert_equal(
                average_loss, {
                    b'0000': np.array(2.5) + rng_term_0,
                    b'1001': np.array(5) + rng_term_1
                })

        with self.subTest('no mask, has regularizer'):
            clients = [
                (b'0000', [{
                    'x': jnp.array([2, 3])
                }], rng_0, jnp.array(0)),
                (b'1001', [{
                    'x': jnp.array([3, 4])
                }, {
                    'x': jnp.array([5])
                }], rng_1, jnp.array(1)),
            ]
            average_loss = dict(
                models.AverageLossEvaluator(
                    per_example_loss,
                    regularizer).evaluate_per_client_params(clients=clients))
            npt.assert_equal(
                average_loss, {
                    b'0000': np.array(2.5) + rng_term_0,
                    b'1001': np.array(5.5) + rng_term_1
                })

        with self.subTest('has mask, no regularizer'):
            clients = [
                (b'0000', [{
                    'x': jnp.array([2, 3])
                }], rng_0, jnp.array(0)),
                (b'1001', [{
                    'x': jnp.array([3, 4])
                }, {
                    'x': jnp.array([5, 10]),
                    '__mask__': jnp.array([True, False])
                }], rng_1, jnp.array(1)),
            ]
            average_loss = dict(
                models.AverageLossEvaluator(per_example_loss).
                evaluate_per_client_params(clients=clients))
            npt.assert_equal(
                average_loss, {
                    b'0000': np.array(2.5) + rng_term_0,
                    b'1001': np.array(5) + rng_term_1
                })
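
The first subtest's expectation for b'1001' can be checked by hand. With params = 1, the evaluator sees batches [3, 4] and [5]; each example contributes params + x plus the uniform draw for its batch's rng, and batches are weighted by example count:

# Hand check of average_loss[b'1001'] in the 'no mask, no regularizer' subtest.
batch_1 = (1 + 3) + (1 + 4)       # two examples sharing the first rng draw
batch_2 = (1 + 5)                 # one example with the second rng draw
base = (batch_1 + batch_2) / 3    # = 15 / 3 = 5
# Adding the rng contribution (2 * u_10 + u_11) / 3 gives np.array(5) + rng_term_1.
# With the regularizer, 0.5 * 1**2 = 0.5 more, matching the 5.5 in the second subtest.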