def build_federated_evaluation(model_fn):
    """Builds the TFF computation for federated evaluation of the given model.

  Args:
    model_fn: A no-argument function that returns a `tff.learning.Model`.

  Returns:
    A federated computation (an instance of `tff.Computation`) that accepts
    model parameters and federated data, and returns the evaluation metrics
    as aggregated by `tff.learning.Model.federated_output_computation`.
  """
    # Construct the model first just to obtain the metadata and define all the
    # types needed to define the computations that follow.
    # TODO(b/124477628): Ideally replace the need for stamping throwaway models
    # with some other mechanism.
    with tf.Graph().as_default():
        model = model_utils.enhance(model_fn())
        model_weights_type = tff.to_type(
            tf.nest.map_structure(
                lambda v: tff.TensorType(v.dtype.base_dtype, v.shape),
                model.weights))
        batch_type = tff.to_type(model.input_spec)

    @tff.tf_computation(model_weights_type, tff.SequenceType(batch_type))
    def client_eval(incoming_model_weights, dataset):
        """Returns local outputs after evaluting `model_weights` on `dataset`."""
        model = model_utils.enhance(model_fn())

        # TODO(b/124477598): Remove dummy when b/121400757 has been fixed.
        @tf.function
        def reduce_fn(dummy, batch):
            model_output = model.forward_pass(batch, training=False)
            return dummy + tf.cast(model_output.loss, tf.float64)

        # TODO(b/123898430): The control dependencies below have been inserted as a
        # temporary workaround. These control dependencies need to be removed, and
        # defuns and datasets supported together fully.
        with tf.control_dependencies(
            [tff.utils.assign(model.weights, incoming_model_weights)]):
            dummy = dataset.reduce(tf.constant(0.0, dtype=tf.float64),
                                   reduce_fn)

        with tf.control_dependencies([dummy]):
            return collections.OrderedDict([
                ('local_outputs', model.report_local_outputs()),
                ('workaround for b/121400757', dummy)
            ])

    @tff.federated_computation(
        tff.FederatedType(model_weights_type, tff.SERVER),
        tff.FederatedType(tff.SequenceType(batch_type), tff.CLIENTS))
    def server_eval(server_model_weights, federated_dataset):
        client_outputs = tff.federated_map(
            client_eval,
            [tff.federated_broadcast(server_model_weights), federated_dataset])
        return model.federated_output_computation(client_outputs.local_outputs)

    return server_eval
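# --- Hypothetical usage sketch (not part of the module above) ---
# Wrap a Keras model as a `tff.learning.Model`, build the evaluation
# computation, and invoke it on one batched `tf.data.Dataset` per client.
# Note: the `tff.learning.from_keras_model` keyword arguments have changed
# across TFF releases, so treat the wrapper call as illustrative.
import collections

import tensorflow as tf
import tensorflow_federated as tff


def model_fn():
  keras_model = tf.keras.Sequential(
      [tf.keras.layers.Dense(1, input_shape=(2,))])
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=collections.OrderedDict(
          x=tf.TensorSpec(shape=[None, 2], dtype=tf.float32),
          y=tf.TensorSpec(shape=[None, 1], dtype=tf.float32)),
      loss=tf.keras.losses.MeanSquaredError())


evaluation = build_federated_evaluation(model_fn)


def make_client_dataset():
  # Two examples per client, batched to match the model's input_spec.
  return tf.data.Dataset.from_tensor_slices(
      collections.OrderedDict(
          x=[[1.0, 2.0], [3.0, 4.0]], y=[[5.0], [6.0]])).batch(2)


federated_test_data = [make_client_dataset(), make_client_dataset()]

# `model_weights` would normally come from a training process, e.g. the
# `model` field of a `tff.learning.build_federated_averaging_process` state:
# metrics = evaluation(model_weights, federated_test_data)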
def build_federated_evaluation(model_fn):
    """Builds the TFF computation for federated evaluation of the given model.

  Args:
    model_fn: A no-argument function that returns a `tff.learning.Model`.

  Returns:
    A federated computation (an instance of `tff.Computation`) that accepts
    model parameters and federated data, and returns the evaluation metrics
    as aggregated by `tff.learning.Model.federated_output_computation`.
  """
    # Construct the model first just to obtain the metadata and define all the
    # types needed to define the computations that follow.
    # TODO(b/124477628): Ideally replace the need for stamping throwaway models
    # with some other mechanism.
    with tf.Graph().as_default():
        model = model_utils.enhance(model_fn())
        model_weights_type = tff.to_type(
            tf.nest.map_structure(
                lambda v: tff.TensorType(v.dtype.base_dtype, v.shape),
                model.weights))
        batch_type = tff.to_type(model.input_spec)

    @tff.tf_computation(model_weights_type, tff.SequenceType(batch_type))
    def client_eval(incoming_model_weights, dataset):
        """Returns local outputs after evaluting `model_weights` on `dataset`."""

        model = model_utils.enhance(model_fn())

        @tf.function
        def _tf_client_eval(incoming_model_weights, dataset):
            """Evaluation TF work."""

            tff.utils.assign(model.weights, incoming_model_weights)

            def reduce_fn(prev_loss, batch):
                model_output = model.forward_pass(batch, training=False)
                return prev_loss + tf.cast(model_output.loss, tf.float64)

            # The reduce is run for its side effects: each `forward_pass` call
            # updates the model's local metric variables, which are read below
            # by `report_local_outputs`. The accumulated loss is discarded.
            dataset.reduce(tf.constant(0.0, dtype=tf.float64), reduce_fn)

            return collections.OrderedDict([('local_outputs',
                                             model.report_local_outputs())])

        return _tf_client_eval(incoming_model_weights, dataset)

    @tff.federated_computation(
        tff.FederatedType(model_weights_type, tff.SERVER),
        tff.FederatedType(tff.SequenceType(batch_type), tff.CLIENTS))
    def server_eval(server_model_weights, federated_dataset):
        client_outputs = tff.federated_map(
            client_eval,
            [tff.federated_broadcast(server_model_weights), federated_dataset])
        return model.federated_output_computation(client_outputs.local_outputs)

    return server_eval
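# --- Why the rewrite above works without control dependencies ---
# Inside a single `tf.function`, TensorFlow's automatic control dependencies
# order stateful ops on the same resource in program order, so the weight
# assignment is guaranteed to run before the dataset reduction that reads the
# weights. A hypothetical toy sketch of that sequencing guarantee:
import tensorflow as tf

v = tf.Variable(0.0)


@tf.function
def assign_then_reduce(new_value, dataset):
  # No explicit tf.control_dependencies needed: the assignment below is
  # automatically sequenced before the reduce, which captures `v`.
  v.assign(new_value)
  return dataset.reduce(tf.constant(0.0), lambda acc, x: acc + x * v)


ds = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0])
print(assign_then_reduce(2.0, ds).numpy())  # (1 + 2 + 3) * 2 = 12.0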
  def test_success_with_valid_context(self):

    def model_fn():
      return model_examples.LinearRegression(feature_dim=2)

    zero_model_weights = _create_zero_model_weights(model_fn)
    p13n_fn_dict = _create_p13n_fn_dict(learning_rate=1.0)

    # Build the p13n eval with an extra `context` argument.
    context_tff_type = tff.to_type(tf.int32)
    federated_p13n_eval = p13n_eval.build_personalization_eval(
        model_fn, p13n_fn_dict, _evaluate_fn, context_tff_type=context_tff_type)

    # Perform p13n eval on two clients with different `context` values.
    results = federated_p13n_eval(zero_model_weights, [
        _create_client_input(train_scale=1.0, test_scale=1.0, context=2),
        _create_client_input(train_scale=1.0, test_scale=2.0, context=5)
    ])
    results = results._asdict(recursive=True)

    bs1_metrics = results['batch_size_1']
    bs2_metrics = results['batch_size_2']

    # Number of training examples is `3 + context` for both clients.
    # Note: the order is not preserved due to `federated_sample`, but the order
    # should be consistent across different personalization strategies.
    self.assertAllEqual(sorted(bs1_metrics['num_examples']), [5, 8])
    self.assertAllEqual(bs1_metrics['num_examples'],
                        bs2_metrics['num_examples'])
  def test_failure_with_invalid_context_type(self):

    def model_fn():
      return model_examples.LinearRegression(feature_dim=2)

    zero_model_weights = _create_zero_model_weights(model_fn)
    p13n_fn_dict = _create_p13n_fn_dict(learning_rate=1.0)

    with self.assertRaises(TypeError):
      # `tf.int32` is not a `tff.Type`.
      bad_context_tff_type = tf.int32
      federated_p13n_eval = p13n_eval.build_personalization_eval(
          model_fn,
          p13n_fn_dict,
          _evaluate_fn,
          context_tff_type=bad_context_tff_type)

    with self.assertRaises(TypeError):
      # `context_tff_type` is provided but `context` is not provided.
      context_tff_type = tff.to_type(tf.int32)
      federated_p13n_eval = p13n_eval.build_personalization_eval(
          model_fn,
          p13n_fn_dict,
          _evaluate_fn,
          context_tff_type=context_tff_type)
      federated_p13n_eval(zero_model_weights, [
          _create_client_input(train_scale=1.0, test_scale=1.0, context=None),
          _create_client_input(train_scale=1.0, test_scale=2.0, context=None)
      ])
def build_personalization_eval(model_fn,
                               personalize_fn_dict,
                               baseline_evaluate_fn,
                               max_num_samples=100,
                               context_tff_type=None):
  """Builds the TFF computation for evaluating personalization strategies.

  The returned TFF computation broadcasts model weights from SERVER to CLIENTS.
  Each client evaluates the personalization strategies given in
  `personalize_fn_dict`. Evaluation metrics from at most `max_num_samples`
  participating clients are collected to the SERVER.

  Args:
    model_fn: A no-argument function that returns a `tff.learning.Model`.
    personalize_fn_dict: An `OrderedDict` that maps a `string` (representing a
      strategy name) to a no-argument function that returns a `tf.function`.
      Each `tf.function` represents a personalization strategy: it accepts a
      `tff.learning.Model` (with weights already initialized to the provided
      model weights when users invoke the returned TFF computation), a training
      `tf.data.Dataset`, a test `tf.data.Dataset`, and an arbitrary
      context object (which is used to hold any extra information that a
      personalization strategy may use), trains a personalized model, and
      returns the evaluation metrics. The evaluation metrics are usually
      represented as an `OrderedDict` (or a nested `OrderedDict`) of `string`
      metric names to scalar `tf.Tensor`s.
    baseline_evaluate_fn: A `tf.function` that accepts a `tff.learning.Model`
      (with weights already initialized to the provided model weights when users
      invoke the returned TFF computation), and a `tf.data.Dataset`,
      evaluates the model on the dataset, and returns the evaluation metrics.
      The evaluation metrics are usually represented as an `OrderedDict` (or a
      nested `OrderedDict`) of `string` metric names to scalar `tf.Tensor`s.
      This function is *only* used to compute the baseline metrics of the
      initial model.
    max_num_samples: A positive `int` specifying the maximum number of metric
      samples to collect in a round. Each sample contains the personalization
      metrics from a single client. If the number of participating clients in a
      round is smaller than this value, all clients' metrics are collected.
    context_tff_type: A `tff.Type` of the optional context object used by the
      personalization strategies defined in `personalize_fn_dict`. We use a
      context object to hold any extra information (in addition to the training
      dataset) that personalization may use. If context is used in
      `personalize_fn_dict`, its `tff.Type` must be provided here.

  Returns:
    A federated `tff.Computation` that maps
    < model_weights@SERVER, input@CLIENTS > -> personalization_metrics@SERVER,
    where:
    - model_weights is a `tff.learning.framework.ModelWeights`.
    - each client's input is an `OrderedDict` of at least two keys `train_data`
      and `test_data`, and each key is mapped to a `tf.data.Dataset`. If
      context is used in `personalize_fn_dict`, then client input has a third
      key `context` that is mapped to an object whose `tff.Type` is provided by
      the `context_tff_type` argument.
    - personalization_metrics is an `OrderedDict` that maps a key
      'baseline_metrics' to the evaluation metrics of the initial model
      (computed by `baseline_evaluate_fn`), and maps keys (strategy names) in
      `personalize_fn_dict` to the evaluation metrics of the corresponding
      personalization strategies.
    - Note: only metrics from at most `max_num_samples` participating clients
      are collected to the SERVER. All collected metrics are stored in a
      single `OrderedDict` (the personalization_metrics shown above), where each
      metric is mapped to a list of scalars (each scalar comes from one client).
      Metric values at the same position, e.g., metric_1[i], metric_2[i]..., all
      come from the same client.

  Raises:
    TypeError: If arguments are of the wrong types.
    ValueError: If `baseline_metrics` is used as a key in `personalize_fn_dict`.
    ValueError: If `max_num_samples` is not positive.
  """
  # Obtain the types by constructing the model first.
  # TODO(b/124477628): Replace it with other ways of handling metadata.
  with tf.Graph().as_default():
    py_typecheck.check_callable(model_fn)
    model = model_utils.enhance(model_fn())
    model_weights_type = tff.framework.type_from_tensors(model.weights)
    batch_type = tff.to_type(model.input_spec)

  # Define the `tff.Type` of each client's input.
  client_input_type = collections.OrderedDict([
      ('train_data', tff.SequenceType(batch_type)),
      ('test_data', tff.SequenceType(batch_type))
  ])
  if context_tff_type is not None:
    py_typecheck.check_type(context_tff_type, tff.Type)
    client_input_type['context'] = context_tff_type
  client_input_type = tff.to_type(client_input_type)

  @tff.tf_computation(model_weights_type, client_input_type)
  def _client_computation(initial_model_weights, client_input):
    """TFF computation that runs on each client."""
    model = model_fn()
    train_data = client_input['train_data']
    test_data = client_input['test_data']
    context = client_input.get('context', None)
    return _client_fn(model, initial_model_weights, train_data, test_data,
                      personalize_fn_dict, baseline_evaluate_fn, context)

  py_typecheck.check_type(max_num_samples, int)
  if max_num_samples <= 0:
    raise ValueError('max_num_samples must be a positive integer.')

  @tff.federated_computation(
      tff.FederatedType(model_weights_type, tff.SERVER),
      tff.FederatedType(client_input_type, tff.CLIENTS))
  def personalization_eval(server_model_weights, federated_client_input):
    """TFF orchestration logic."""
    client_init_weights = tff.federated_broadcast(server_model_weights)
    client_final_metrics = tff.federated_map(
        _client_computation, (client_init_weights, federated_client_input))

    # WARNING: Collecting information from clients can be risky. Users have to
    # make sure that it is proper to collect those metrics from clients.
    # TODO(b/147889283): Add a link to the TFF doc once it exists.
    results = tff.utils.federated_sample(client_final_metrics, max_num_samples)
    return results

  return personalization_eval
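# --- Hypothetical `personalize_fn_dict` entry (illustrative only) ---
# A simple SGD fine-tuning strategy. In this version of the API the client
# datasets already yield *batched* elements matching `model.input_spec`, so
# the strategy iterates over them directly. Only the documented
# `tff.learning.Model` interface (`forward_pass`, `trainable_variables`) is
# assumed; the names and hyperparameters below are placeholders.
import collections

import tensorflow as tf


def build_fine_tune_fn(learning_rate=0.1, num_epochs=1):
  optimizer = tf.keras.optimizers.SGD(learning_rate)

  @tf.function
  def fine_tune_fn(model, train_data, test_data, context=None):
    del context  # Unused by this strategy.
    # Fine-tune on the client's training data.
    for batch in train_data.repeat(num_epochs):
      with tf.GradientTape() as tape:
        output = model.forward_pass(batch, training=True)
      grads = tape.gradient(output.loss, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))
    # Evaluate the personalized model on the client's test data.
    num_examples = tf.constant(0.0)
    loss_sum = tf.constant(0.0)
    for batch in test_data:
      output = model.forward_pass(batch, training=False)
      n = tf.cast(output.num_examples, tf.float32)
      loss_sum += output.loss * n
      num_examples += n
    return collections.OrderedDict(
        num_examples=num_examples, test_loss=loss_sum / num_examples)

  return fine_tune_fn


personalize_fn_dict = collections.OrderedDict(
    fine_tune=lambda: build_fine_tune_fn(learning_rate=0.1))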
def build_personalization_eval(model_fn,
                               personalize_fn_dict,
                               baseline_evaluate_fn,
                               max_num_samples=100,
                               context_tff_type=None):
    """Builds the TFF computation for evaluating personalization strategies.

  The returned TFF computation broadcasts model weights from `tff.SERVER` to
  `tff.CLIENTS`. Each client evaluates the personalization strategies given in
  `personalize_fn_dict`. Evaluation metrics from at most `max_num_samples`
  participating clients are collected to the server.

  NOTE: The functions in `personalize_fn_dict` and `baseline_evaluate_fn` are
  expected to take as input *unbatched* datasets, and are responsible for
  applying batching, if any, to the provided input datasets.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`. This
      method must *not* capture TensorFlow tensors or variables and use them.
      The model must be constructed entirely from scratch on each invocation;
      returning a pre-constructed model on each call will result in an error.
    personalize_fn_dict: An `OrderedDict` that maps a `string` (representing a
      strategy name) to a no-argument function that returns a `tf.function`.
      Each `tf.function` represents a personalization strategy: it accepts a
      `tff.learning.Model` (with weights already initialized to the given model
      weights when users invoke the returned TFF computation), an unbatched
      `tf.data.Dataset` for train, an unbatched `tf.data.Dataset` for test, and
      an arbitrary context object (which is used to hold any extra information
      that a personalization strategy may use), trains a personalized model, and
      returns the evaluation metrics. The evaluation metrics are represented as
      an `OrderedDict` (or a nested `OrderedDict`) of `string` metric names to
      scalar `tf.Tensor`s.
    baseline_evaluate_fn: A `tf.function` that accepts a `tff.learning.Model`
      (with weights already initialized to the provided model weights when users
      invoke the returned TFF computation), and an unbatched `tf.data.Dataset`,
      evaluates the model on the dataset, and returns the evaluation metrics.
      The evaluation metrics are represented as an `OrderedDict` (or a nested
      `OrderedDict`) of `string` metric names to scalar `tf.Tensor`s. This
      function is *only* used to compute the baseline metrics of the initial
      model.
    max_num_samples: A positive `int` specifying the maximum number of metric
      samples to collect in a round. Each sample contains the personalization
      metrics from a single client. If the number of participating clients in a
      round is smaller than this value, all clients' metrics are collected.
    context_tff_type: A `tff.Type` of the optional context object used by the
      personalization strategies defined in `personalize_fn_dict`. We use a
      context object to hold any extra information (in addition to the training
      dataset) that personalization may use. If context is used in
      `personalize_fn_dict`, its `tff.Type` must be provided here.

  Returns:
    A federated `tff.Computation` with the functional type signature
    `(<model_weights@SERVER, input@CLIENTS> -> personalization_metrics@SERVER)`:

    *   `model_weights` is a `tff.learning.ModelWeights`.
    *   Each client's input is an `OrderedDict` of two required keys
        `train_data` and `test_data`; each key is mapped to an unbatched
        `tf.data.Dataset`. If extra context (e.g., extra datasets) is used in
        `personalize_fn_dict`, then client input has a third key `context` that
        is mapped to an object whose `tff.Type` is provided by the
        `context_tff_type` argument.
    *   `personalization_metrics` is an `OrderedDict` that maps a key
        'baseline_metrics' to the evaluation metrics of the initial model
        (computed by `baseline_evaluate_fn`), and maps keys (strategy names) in
        `personalize_fn_dict` to the evaluation metrics of the corresponding
        personalization strategies.
    *   Note: only metrics from at most `max_num_samples` participating clients
        (sampled without replacement) are collected to the SERVER. All collected
        metrics are stored in a single `OrderedDict` (`personalization_metrics`
        shown above), where each metric is mapped to a list of scalars (each
        scalar comes from one client). Metric values at the same position, e.g.,
        metric_1[i], metric_2[i]..., all come from the same client.

  Raises:
    TypeError: If arguments are of the wrong types.
    ValueError: If `baseline_metrics` is used as a key in `personalize_fn_dict`.
    ValueError: If `max_num_samples` is not positive.
  """
    # Obtain the types by constructing the model first.
    # TODO(b/124477628): Replace it with other ways of handling metadata.
    with tf.Graph().as_default():
        py_typecheck.check_callable(model_fn)
        model = model_utils.enhance(model_fn())
        model_weights_type = tff.framework.type_from_tensors(model.weights)
        batch_type = model.input_spec

    # Define the `tff.Type` of each client's input. Since batching (as well as
    # other preprocessing of datasets) is done within each personalization
    # strategy (i.e., by functions in `personalize_fn_dict`), the client-side
    # input should contain unbatched elements.
    element_type = _remove_batch_dim(batch_type)
    client_input_type = collections.OrderedDict([
        ('train_data', tff.SequenceType(element_type)),
        ('test_data', tff.SequenceType(element_type))
    ])
    if context_tff_type is not None:
        py_typecheck.check_type(context_tff_type, tff.Type)
        client_input_type['context'] = context_tff_type
    client_input_type = tff.to_type(client_input_type)

    @tff.tf_computation(model_weights_type, client_input_type)
    def _client_computation(initial_model_weights, client_input):
        """TFF computation that runs on each client."""
        train_data = client_input['train_data']
        test_data = client_input['test_data']
        context = client_input.get('context', None)

        final_metrics = collections.OrderedDict()
        # Compute the evaluation metrics of the initial model.
        final_metrics['baseline_metrics'] = _compute_baseline_metrics(
            model_fn, initial_model_weights, test_data, baseline_evaluate_fn)

        py_typecheck.check_type(personalize_fn_dict, collections.OrderedDict)
        if 'baseline_metrics' in personalize_fn_dict:
            raise ValueError('baseline_metrics should not be used as a key in '
                             'personalize_fn_dict.')

        # Compute the evaluation metrics of the personalized models. The returned
        # `p13n_metrics` is an `OrderedDict` that maps keys (strategy names) in
        # `personalize_fn_dict` to the evaluation metrics of the corresponding
        # personalization strategies.
        p13n_metrics = _compute_p13n_metrics(model_fn, initial_model_weights,
                                             train_data, test_data,
                                             personalize_fn_dict, context)
        final_metrics.update(p13n_metrics)
        return final_metrics

    py_typecheck.check_type(max_num_samples, int)
    if max_num_samples <= 0:
        raise ValueError('max_num_samples must be a positive integer.')

    @tff.federated_computation(
        tff.FederatedType(model_weights_type, tff.SERVER),
        tff.FederatedType(client_input_type, tff.CLIENTS))
    def personalization_eval(server_model_weights, federated_client_input):
        """TFF orchestration logic."""
        client_init_weights = tff.federated_broadcast(server_model_weights)
        client_final_metrics = tff.federated_map(
            _client_computation, (client_init_weights, federated_client_input))

        # WARNING: Collecting information from clients can be risky. Users have to
        # make sure that it is proper to collect those metrics from clients.
        # TODO(b/147889283): Add a link to the TFF doc once it exists.
        results = tff.utils.federated_sample(client_final_metrics,
                                             max_num_samples)
        return results

    return personalization_eval
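# --- Hypothetical client input for the unbatched contract above ---
# Each client supplies *unbatched* `train_data` and `test_data` (strategies
# apply their own batching), plus a `context` value when `context_tff_type`
# is given. The feature names and the int32 context here are placeholders
# that must match the model's `input_spec` and `context_tff_type`.
import collections

import tensorflow as tf


def make_client_input(context_value):
  dataset = tf.data.Dataset.from_tensor_slices(
      collections.OrderedDict(
          x=[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
          y=[[1.0], [2.0], [3.0]]))  # unbatched: one example per element
  return collections.OrderedDict(
      train_data=dataset, test_data=dataset, context=context_value)


federated_client_input = [make_client_input(2), make_client_input(5)]

# p13n_eval = build_personalization_eval(
#     model_fn, personalize_fn_dict, baseline_evaluate_fn,
#     context_tff_type=tff.to_type(tf.int32))
# results = p13n_eval(model_weights, federated_client_input)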