예제 #1
0
 def test_build_factory_fails_invalid_argument(self):
     """Rejects non-positive and non-integer `sample_size` arguments."""
     # Each entry pairs the expected exception type with an invalid
     # `sample_size` value.
     invalid_cases = (
         (ValueError, 0),
         (ValueError, -1),
         (TypeError, None),
         (TypeError, '5'),
     )
     for expected_error, bad_sample_size in invalid_cases:
         with self.assertRaises(expected_error):
             sampling.UnweightedReservoirSamplingFactory(
                 sample_size=bad_sample_size)
예제 #2
0
 def test_measurements_structure_value(self):
     """Checks per-leaf measurements when sampling structured client values."""
     factory = sampling.UnweightedReservoirSamplingFactory(sample_size=1)
     value_type = computation_types.to_type(
         collections.OrderedDict(
             a=TensorType(tf.float32),
             b=[TensorType(tf.float32, [2, 2]),
                TensorType(tf.bool)]))
     process = factory.create(value_type)
     state = process.initialize()

     # Three clients; some leaves carry `nan` / `inf` entries.
     client_values = [
         collections.OrderedDict(
             a=1.0, b=[[[1.0, np.nan], [np.inf, 4.0]], True]),
         collections.OrderedDict(
             a=2.0, b=[[[1.0, 2.0], [3.0, 4.0]], False]),
         collections.OrderedDict(
             a=np.inf, b=[[[np.nan, 2.0], [3.0, 4.0]], True]),
     ]
     output = process.next(state, client_values)

     expected_measurements = collections.OrderedDict(
         # One client has non-finite tensors for this leaf node.
         a=tf.constant(1, dtype=tf.int64),
         # Two clients have non-finite tensors for the first leaf of `b`;
         # none is counted for the boolean leaf.
         b=[
             tf.constant(2, dtype=tf.int64),
             tf.constant(0, dtype=tf.int64)
         ])
     self.assertEqual(output.measurements, expected_measurements)
예제 #3
0
 def test_measurements_scalar_value(self):
     """Checks the measurement reported for scalar float client values."""
     factory = sampling.UnweightedReservoirSamplingFactory(sample_size=1)
     process = factory.create(computation_types.to_type(tf.float32))
     state = process.initialize()
     client_values = [1.0, np.nan, np.inf, 2.0, 3.0]
     output = process.next(state, client_values)
     # Two clients' values are non-finite (`nan` and `inf`).
     expected_measurements = tf.constant(2, dtype=tf.int64)
     self.assertEqual(output.measurements, expected_measurements)
예제 #4
0
 def test_construction_fails_with_invalid_aggregation_factory(self):
   """Expects a `TypeError` when a reservoir sampling factory is supplied."""
   expected_message = 'does not produce a compatible `AggregationProcess`'
   sampling_factory = sampling.UnweightedReservoirSamplingFactory(
       sample_size=1)
   with self.assertRaisesRegex(TypeError, expected_message):
     optimizer_utils.build_model_delta_optimizer_process(
         model_fn=model_examples.LinearRegression,
         model_to_client_delta_fn=DummyClientDeltaFn,
         server_optimizer_fn=tf.keras.optimizers.SGD,
         model_update_aggregation_factory=sampling_factory)
예제 #5
0
 def test_create(self):
     """Creates aggregators for a scalar type and for a nested structure."""
     factory = sampling.UnweightedReservoirSamplingFactory(sample_size=10)

     with self.subTest('scalar_aggregator'):
         scalar_type = computation_types.to_type(tf.int32)
         factory.create(scalar_type)

     with self.subTest('structure_aggregator'):
         structure_type = computation_types.to_type(
             collections.OrderedDict(
                 a=TensorType(tf.int32),
                 b=[TensorType(tf.float32, [3]),
                    TensorType(tf.bool)]))
         factory.create(structure_type)
예제 #6
0
 def test_unfilled_reservoir(self):
     """Expects all client values back when there are fewer than `sample_size`."""
     factory = sampling.UnweightedReservoirSamplingFactory(sample_size=4)
     process = factory.create(computation_types.to_type(tf.int32))
     state = process.initialize()
     # Only 3 client values are generated, fewer than the sample size of 4.
     seed = tf.convert_to_tensor((TEST_SEED, TEST_SEED))
     random_values = tf.random.stateless_uniform(
         shape=(3, ), minval=None, seed=seed, dtype=tf.int32)
     client_values = random_values.numpy().tolist()
     output = process.next(state, client_values)
     self.assertCountEqual(output.result, client_values)
예제 #7
0
 def test_sample_size_limits(self, sample_size):
     """Expects `sample_size` results when twice as many clients report."""
     factory = sampling.UnweightedReservoirSamplingFactory(
         sample_size=sample_size)
     process = factory.create(computation_types.to_type(tf.int32))
     state = process.initialize()
     # Generate 2 * sample_size values from clients.
     seed = tf.convert_to_tensor((TEST_SEED, TEST_SEED))
     client_values = tf.random.stateless_uniform(
         shape=(sample_size * 2, ),
         minval=None,
         seed=seed,
         dtype=tf.int32).numpy().tolist()
     output = process.next(state, client_values)
     self.assertEqual(output.result.shape, (sample_size, ))
예제 #8
0
def build_personalization_eval(model_fn,
                               personalize_fn_dict,
                               baseline_evaluate_fn,
                               max_num_clients=100,
                               context_tff_type=None):
  """Builds the TFF computation for evaluating personalization strategies.

  The returned TFF computation broadcasts model weights from `tff.SERVER` to
  `tff.CLIENTS`. Each client evaluates the personalization strategies given in
  `personalize_fn_dict`. Evaluation metrics from at most `max_num_clients`
  participating clients are collected to the server.

  NOTE: The functions in `personalize_fn_dict` and `baseline_evaluate_fn` are
  expected to take as input *unbatched* datasets, and are responsible for
  applying batching, if any, to the provided input datasets.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`. This method
      must *not* capture TensorFlow tensors or variables and use them. The model
      must be constructed entirely from scratch on each invocation, returning
      the same pre-constructed model each call will result in an error.
    personalize_fn_dict: An `OrderedDict` that maps a `string` (representing a
      strategy name) to a no-argument function that returns a `tf.function`.
      Each `tf.function` represents a personalization strategy - it accepts a
      `tff.learning.Model` (with weights already initialized to the given model
      weights when users invoke the returned TFF computation), an unbatched
      `tf.data.Dataset` for train, an unbatched `tf.data.Dataset` for test, and
      an arbitrary context object (which is used to hold any extra information
      that a personalization strategy may use), trains a personalized model, and
      returns the evaluation metrics. The evaluation metrics are represented as
      an `OrderedDict` (or a nested `OrderedDict`) of `string` metric names to
      scalar `tf.Tensor`s.
    baseline_evaluate_fn: A `tf.function` that accepts a `tff.learning.Model`
      (with weights already initialized to the provided model weights when users
      invoke the returned TFF computation), and an unbatched `tf.data.Dataset`,
      evaluates the model on the dataset, and returns the evaluation metrics.
      The evaluation metrics are represented as an `OrderedDict` (or a nested
      `OrderedDict`) of `string` metric names to scalar `tf.Tensor`s. This
      function is *only* used to compute the baseline metrics of the initial
      model.
    max_num_clients: A positive `int` specifying the maximum number of clients
      to collect metrics in a round (default is 100). The clients are sampled
      without replacement. For each sampled client, all the personalization
      metrics from this client will be collected. If the number of participating
      clients in a round is smaller than this value, then metrics from all
      clients will be collected.
    context_tff_type: A `tff.Type` of the optional context object used by the
      personalization strategies defined in `personalize_fn_dict`. We use a
      context object to hold any extra information (in addition to the training
      dataset) that personalization may use. If context is used in
      `personalize_fn_dict`, its `tff.Type` must be provided here.

  Returns:
    A federated `tff.Computation` with the functional type signature
    `(<model_weights@SERVER, input@CLIENTS> -> personalization_metrics@SERVER)`:

    *   `model_weights` is a `tff.learning.ModelWeights`.
    *   Each client's input is an `OrderedDict` of two required keys
        `train_data` and `test_data`; each key is mapped to an unbatched
        `tf.data.Dataset`. If extra context (e.g., extra datasets) is used in
        `personalize_fn_dict`, then client input has a third key `context` that
        is mapped to an object whose `tff.Type` is provided by the
        `context_tff_type` argument.
    *   `personalization_metrics` is an `OrderedDict` that maps a key
        'baseline_metrics' to the evaluation metrics of the initial model
        (computed by `baseline_evaluate_fn`), and maps keys (strategy names) in
        `personalize_fn_dict` to the evaluation metrics of the corresponding
        personalization strategies.
    *   Note: only metrics from at most `max_num_clients` participating clients
        (sampled without replacement) are collected to the SERVER. All collected
        metrics are stored in a single `OrderedDict` (`personalization_metrics`
        shown above), where each metric is mapped to a list of scalars (each
        scalar comes from one client). Metric values at the same position, e.g.,
        metric_1[i], metric_2[i]..., all come from the same client.

  Raises:
    TypeError: If arguments are of the wrong types.
    ValueError: If `baseline_metrics` is used as a key in `personalize_fn_dict`.
    ValueError: If `max_num_clients` is not positive.
  """
  # Validate the arguments up front so that an invalid call fails before a
  # model is constructed or any computation is traced.
  py_typecheck.check_callable(model_fn)
  py_typecheck.check_type(personalize_fn_dict, collections.OrderedDict)
  if 'baseline_metrics' in personalize_fn_dict:
    raise ValueError('baseline_metrics should not be used as a key in '
                     'personalize_fn_dict.')
  py_typecheck.check_type(max_num_clients, int)
  if max_num_clients <= 0:
    raise ValueError('max_num_clients must be a positive integer.')
  if context_tff_type is not None:
    py_typecheck.check_type(context_tff_type, computation_types.Type)

  # Obtain the types by constructing the model first.
  # TODO(b/124477628): Replace it with other ways of handling metadata.
  with tf.Graph().as_default():
    model = model_fn()
    model_weights_type = model_utils.weights_type_from_model(model)
    batch_tff_type = computation_types.to_type(model.input_spec)

  # Define the `tff.Type` of each client's input. Since batching (as well as
  # other preprocessing of datasets) is done within each personalization
  # strategy (i.e., by functions in `personalize_fn_dict`), the client-side
  # input should contain unbatched elements.
  element_tff_type = _remove_batch_dim(batch_tff_type)
  client_input_type = collections.OrderedDict(
      train_data=computation_types.SequenceType(element_tff_type),
      test_data=computation_types.SequenceType(element_tff_type))
  if context_tff_type is not None:
    # The `context` key is only present when a context type was supplied.
    client_input_type['context'] = context_tff_type
  client_input_type = computation_types.to_type(client_input_type)

  @computations.tf_computation(model_weights_type, client_input_type)
  def _client_computation(initial_model_weights, client_input):
    """TFF computation that runs on each client."""
    train_data = client_input['train_data']
    test_data = client_input['test_data']
    # `context` is `None` when no `context_tff_type` was given.
    context = client_input.get('context', None)

    final_metrics = collections.OrderedDict()
    # Compute the evaluation metrics of the initial model.
    final_metrics['baseline_metrics'] = _compute_baseline_metrics(
        model_fn, initial_model_weights, test_data, baseline_evaluate_fn)

    # Compute the evaluation metrics of the personalized models. The returned
    # `p13n_metrics` is an `OrderedDict` that maps keys (strategy names) in
    # `personalize_fn_dict` to the evaluation metrics of the corresponding
    # personalization strategies.
    p13n_metrics = _compute_p13n_metrics(model_fn, initial_model_weights,
                                         train_data, test_data,
                                         personalize_fn_dict, context)
    final_metrics.update(p13n_metrics)
    return final_metrics

  # Sample the metrics of at most `max_num_clients` clients.
  reservoir_sampling_factory = sampling.UnweightedReservoirSamplingFactory(
      sample_size=max_num_clients)
  aggregation_process = reservoir_sampling_factory.create(
      _client_computation.type_signature.result)

  @computations.federated_computation(
      computation_types.FederatedType(model_weights_type, placements.SERVER),
      computation_types.FederatedType(client_input_type, placements.CLIENTS))
  def personalization_eval(server_model_weights, federated_client_input):
    """TFF orchestration logic."""
    client_init_weights = intrinsics.federated_broadcast(server_model_weights)
    client_final_metrics = intrinsics.federated_map(
        _client_computation, (client_init_weights, federated_client_input))

    # WARNING: Collecting information from clients can be risky. Users have to
    # make sure that it is proper to collect those metrics from clients.
    # TODO(b/147889283): Add a link to the TFF doc once it exists.
    sampling_output = aggregation_process.next(
        # A fresh state is created on every invocation; nothing is carried
        # over between calls.
        aggregation_process.initialize(),
        client_final_metrics)
    # Only the sampled metrics are returned; `sampling_output.measurements`
    # is not surfaced to the caller.
    return sampling_output.result

  return personalization_eval