Example #1
def evaluate_multiple_clients(
    federated_data: FederatedData,
    client_ids: List[str],
    model: Model,
    params: Params,
    client_data_hparams: Optional[dataset_util.ClientDataHParams] = None
) -> Iterator[MetricResults]:
    """Evaluates model over input clients' datasets.

  Args:
    federated_data: Federated data separated per client.
    client_ids: Ids of clients to evaluate.
    model: Model implementation.
    params: Pytree of model parameters to be evaluated.
    client_data_hparams: Hyperparameters for client dataset preparation.

  Yields:
    Ordered mapping of metrics per client or empty mapping if a client's dataset
      is emtpy.
  """
    if client_data_hparams is not None:
        federated_data = federated_data.preprocess(
            lambda ds: dataset_util.preprocess_tf_dataset(
                ds, client_data_hparams))
    for _, client_dataset in prefetch.PrefetchClientDatasetsIterator(
            federated_data, client_ids):
        yield evaluate_single_client(client_dataset, model, params)
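A minimal call-site sketch (not from the library): `federated_data`, `model`, `params`, and the list `client_ids` are assumed to already exist, for example built with the toy helpers shown in Example #2. Assuming the prefetch iterator preserves the order of `client_ids`, one metrics mapping is yielded per requested client, so zipping recovers a per-client lookup.

# Hypothetical usage; only ClientDataHParams fields that appear elsewhere in
# these examples (batch_size, num_epochs) are relied on.
hparams = dataset_util.ClientDataHParams(batch_size=32, num_epochs=1)
metrics_per_client = dict(
    zip(client_ids,
        evaluate_multiple_clients(
            federated_data=federated_data,
            client_ids=client_ids,
            model=model,
            params=params,
            client_data_hparams=hparams)))
for client_id, metrics in metrics_per_client.items():
  print(client_id, dict(metrics))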
Example #2
    def test_evaluate_single_client(self):
        num_clients = 10
        num_classes = 10
        num_examples = 100
        client_data_hparams = dataset_util.ClientDataHParams(batch_size=20,
                                                             num_epochs=3)
        data, model = test_util.create_toy_example(num_clients=num_clients,
                                                   num_clusters=4,
                                                   num_classes=num_classes,
                                                   num_examples=num_examples,
                                                   seed=0)
        rng_seq = hk.PRNGSequence(0)
        init_params = model.init_params(next(rng_seq))
        dataset = dataset_util.preprocess_tf_dataset(
            dataset_util.create_tf_dataset_for_clients(data),
            client_data_hparams)

        with self.subTest('tf dataset'):
            init_metrics = evaluation_util.evaluate_single_client(
                dataset=dataset, model=model, params=init_params)
            self.assertLess(0.0, init_metrics['loss'])

        with self.subTest('plain iterator'):
            init_metrics = evaluation_util.evaluate_single_client(
                dataset=dataset.as_numpy_iterator(),
                model=model,
                params=init_params)
            self.assertLess(0.0, init_metrics['loss'])
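The test draws the model's initial parameters from Haiku's `hk.PRNGSequence`, which is simply an iterator of fresh JAX PRNG keys. A self-contained sketch of that pattern, independent of FedJAX:

import haiku as hk
import jax.numpy as jnp

rng_seq = hk.PRNGSequence(0)  # deterministic stream of PRNG keys
key_a = next(rng_seq)         # every next() yields a distinct key
key_b = next(rng_seq)
assert not jnp.array_equal(key_a, key_b)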
Example #3
    def test_preprocess_tf_dataset(self, hparams, expected_num_batches):
        x = np.arange(10 * 2).reshape((10, 2))
        numpy_data = collections.OrderedDict(x=x, y=x, z=x)
        dataset = tf.data.Dataset.from_tensor_slices(numpy_data)

        batches = list(dataset_util.preprocess_tf_dataset(dataset, hparams))

        self.assertLen(batches, expected_num_batches)
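The `hparams` and `expected_num_batches` values come from the test's parameterization decorator, which is not shown here. As a stand-alone illustration of the input structure, the same `OrderedDict`-of-arrays dataset can be batched with plain `tf.data` ops; this is only a sketch, not `preprocess_tf_dataset` itself, which additionally honours `ClientDataHParams` fields such as `num_epochs` and `drop_remainder`.

import collections
import numpy as np
import tensorflow as tf

x = np.arange(10 * 2).reshape((10, 2))
dataset = tf.data.Dataset.from_tensor_slices(
    collections.OrderedDict(x=x, y=x, z=x))
# Plain batching for illustration only: 10 rows batched by 4 -> 3 batches,
# the last one partial.
for batch in dataset.batch(4):
  print({k: v.shape for k, v in batch.items()})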
Example #4
def train_multiple_clients(
    federated_data: FederatedData, client_ids: List[str],
    client_trainer: ClientTrainer[T], init_client_trainer_state: T,
    rng_seq: PRNGSequence,
    client_data_hparams: dataset_util.ClientDataHParams) -> Iterator[Any]:
  """Trains separate model for each client and records client updates.

  Depending on the value of `--fedjax_experimental_disable_parallel`, this will
  be run either sequentially or in parallel across local devices via `jax.pmap`.

  Args:
    federated_data: Federated data separated per client.
    client_ids: Ids of clients to train.
    client_trainer: ClientTrainer instance.
    init_client_trainer_state: Initial client trainer state. This will typically
      be derived from algorithm state before calling `train_multiple_clients`.
    rng_seq: Random key generator.
    client_data_hparams: Hyperparameters for client dataset preparation.

  Yields:
    Output of client trainer that is typically just an updated version of the
      input `init_client_trainer_state`. However, output is flexible.
  """
  should_run_parallel = (
      FLAGS.fedjax_experimental_disable_parallel == 'false' or
      (FLAGS.fedjax_experimental_disable_parallel == 'auto' and
       jax.local_device_count() > 1))
  if should_run_parallel:
    yield from _train_multiple_clients_parallel(
        federated_data=federated_data,
        client_ids=client_ids,
        client_trainer=client_trainer,
        init_client_trainer_state=init_client_trainer_state,
        rng_seq=rng_seq,
        client_data_hparams=client_data_hparams)
    return

  preprocessed_federated_data = federated_data.preprocess(
      lambda ds: dataset_util.preprocess_tf_dataset(ds, client_data_hparams))
  for _, client_dataset in prefetch.PrefetchClientDatasetsIterator(
      preprocessed_federated_data, client_ids):
    examples = zip(dataset_util.iterate(client_dataset), rng_seq)
    client_trainer_state = client_trainer.loop(init_client_trainer_state,
                                               examples)
    yield client_trainer_state
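A hedged call-site sketch: `federated_data`, `sampled_client_ids`, a `ClientTrainer` instance, and an initial trainer state are assumed to be provided by the surrounding algorithm; those names are placeholders, not library APIs.

# Hypothetical usage of train_multiple_clients.
rng_seq = hk.PRNGSequence(0)
hparams = dataset_util.ClientDataHParams(batch_size=20, num_epochs=1)
client_states = list(
    train_multiple_clients(
        federated_data=federated_data,
        client_ids=sampled_client_ids,
        client_trainer=client_trainer,
        init_client_trainer_state=init_state,
        rng_seq=rng_seq,
        client_data_hparams=hparams))
# One trainer state per sampled client; what each state contains is defined
# by the ClientTrainer implementation.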
Example #5
def _train_multiple_clients_parallel(
    federated_data: FederatedData, client_ids: List[str],
    client_trainer: ClientTrainer[T], init_client_trainer_state: T,
    rng_seq: PRNGSequence,
    client_data_hparams: dataset_util.ClientDataHParams) -> Iterator[Any]:
  """Trains separate model for each client and records client updates.

  It has the same inputs and return values as `train_multiple_clients` above,
  but parallelizes the training across multiple devices if available.

  Args:
    federated_data: Federated data separated per client.
    client_ids: Ids of clients to train.
    client_trainer: ClientTrainer instance.
    init_client_trainer_state: Initial client trainer state. This will typically
      be derived from algorithm state before calling `train_multiple_clients`.
    rng_seq: Random key generator.
    client_data_hparams: Hyperparameters for client dataset preparation. The
      `drop_remainder` field is automatically set to True to ensure that all
      batches have the same batch size.

  Yields:
    Output of client trainer that is typically just an updated version of the
      input `init_client_trainer_state`. However, output is flexible.
  """
  client_data_hparams = client_data_hparams._replace(drop_remainder=True)
  client_data = []
  preprocessed_federated_data = federated_data.preprocess(
      lambda ds: dataset_util.preprocess_tf_dataset(ds, client_data_hparams))
  for _, client_dataset in prefetch.PrefetchClientDatasetsIterator(
      preprocessed_federated_data, client_ids):
    client_data.append(list(dataset_util.iterate(client_dataset)))
  # Sort by length to group similarly sized datasets together to minimize the
  # amount of fillvalues needed.
  client_data = sorted(client_data, key=len)
  # TODO(b/177346980): Handle case where all clients' data is empty by padding.
  fillvalue = _get_fillvalue(client_data)

  num_local_devices = jax.local_device_count()
  init_stack_state = tree_util.tree_broadcast(
      init_client_trainer_state, axis_size=num_local_devices)
  stack_rng = jrandom.split(next(rng_seq), num_local_devices)

  quotient, remainder = divmod(len(client_ids), num_local_devices)
  num_iterations = quotient + bool(remainder)
  for i in range(num_iterations):
    stack_state = init_stack_state
    streams = client_data[num_local_devices * i:num_local_devices * (i + 1)]
    # Handle number of clients not divisible by num_devices.
    if len(streams) < num_local_devices:
      client_count = len(streams)
      streams.extend(
          [[fillvalue] for _ in range(num_local_devices - client_count)])
    else:
      client_count = num_local_devices

    for batches in itertools.zip_longest(*streams, fillvalue=fillvalue):
      mask = jnp.array([b is not fillvalue for b in batches])
      stack_mask = tree_util.tree_stack(mask)
      stack_batch = tree_util.tree_stack(batches)
      # Mask must be the first input.
      stack_state, stack_rng = _pmap_step(stack_mask, client_trainer,
                                          stack_state, stack_batch, stack_rng)

    # Unstack stack_state, yield each one.
    for j, final_state in enumerate(tree_util.tree_unstack(stack_state)):
      if j < client_count:
        yield final_state
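The padding-and-masking scheme above can be seen in isolation with plain Python lists: unequal-length client streams are zipped with a sentinel `fillvalue`, and the boolean mask marks which slots hold real batches (the masked slots are the ones `_pmap_step` is expected to ignore). This is an illustration, not FedJAX code.

import itertools

fillvalue = object()  # sentinel standing in for a padding batch
streams = [['a1', 'a2', 'a3'], ['b1'], ['c1', 'c2']]
for batches in itertools.zip_longest(*streams, fillvalue=fillvalue):
  mask = [b is not fillvalue for b in batches]
  print([b if m else '<pad>' for b, m in zip(batches, mask)], mask)
# Prints:
#   ['a1', 'b1', 'c1'] [True, True, True]
#   ['a2', '<pad>', 'c2'] [True, False, True]
#   ['a3', '<pad>', '<pad>'] [True, False, False]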