def create_initial_state():
    """Returns the starting `reconstruction_utils.ServerState` for evaluation.

    Only the model slot carries data: whatever
    `reconstruction_utils.get_global_variables` extracts from a freshly built
    model. The optimizer and aggregator slots are empty tuples, and the round
    counter starts at an int64 zero.

    NOTE(review): `model_fn` is not a parameter here. It must be resolvable as
    a global (or enclosing-scope) name when this function is called.
    """
    global_variables = reconstruction_utils.get_global_variables(model_fn())
    first_round = tf.constant(0, dtype=tf.int64)
    return reconstruction_utils.ServerState(
        model=global_variables,
        optimizer_state=(),
        round_num=first_round,
        aggregator_state=(),
    )
    def test_federated_reconstruction_evaluation_process_no_recon(
            self, model_fn):
        """Runs the evaluation process with reconstruction disabled.

        The dataset split uses `recon_epochs_max=0` and `post_recon_epochs=2`.
        Presumably this is why 12 examples are counted for 6 data points.

        The test performs three checks:
        - the `initialize` / `next` type signatures;
        - one round from the default initial state (counting metrics only);
        - one round from a hand-built state whose single trainable weight is
          1.0, for which the loss is known exactly.
        """
        metric_names = ['loss', 'num_examples_total', 'num_over']
        expected_initialize_signature = (
            '( -> <model=<trainable=<float32[1,1]>,non_trainable=<>>,'
            'optimizer_state=<>,round_num=int64,aggregator_state=<>>@SERVER)')
        expected_next_signature = (
            '(<state=<model=<trainable=<float32[1,1]>,non_trainable=<>>,'
            'optimizer_state=<>,round_num=int64,aggregator_state=<>>@SERVER,'
            'data={<x=float32[?,1],y=float32[?,1]>*}@CLIENTS> -> '
            '<<model=<trainable=<float32[1,1]>,non_trainable=<>>,'
            'optimizer_state=<>,round_num=int64,aggregator_state=<>>@SERVER,'
            '<loss=float32,num_examples_total=float32,num_over=float32>@SERVER>)'
        )

        split_fn = reconstruction_utils.build_dataset_split_fn(
            recon_epochs_max=0, post_recon_epochs=2)
        process = evaluation_computation.build_federated_reconstruction_evaluation_process(
            model_fn,
            loss_fn=lambda: tf.keras.losses.MeanSquaredError(),
            metrics_fn=lambda: [NumExamplesCounter(), NumOverCounter(5.0)],
            reconstruction_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1),
            dataset_split_fn=split_fn)

        self.assertEqual(
            str(process.initialize.type_signature),
            expected_initialize_signature)
        self.assertEqual(
            str(process.next.type_signature),
            expected_next_signature)

        # Round 1: default initial state. The weights come from `model_fn`, so
        # the loss is not pinned; only the counting metrics are checked.
        initial_state = process.initialize()
        _, metrics = process.next(initial_state, create_client_data())
        self.assertCountEqual(metrics.keys(), metric_names)
        self.assertAlmostEqual(metrics['num_examples_total'], 12.0)
        self.assertAlmostEqual(metrics['num_over'], 6.0)

        # Round 2: the trainable weight is fixed to 1.0. Without
        # reconstruction, this makes the loss exactly computable.
        unit_weight_state = reconstruction_utils.ServerState(
            model=collections.OrderedDict(
                trainable=[[[1.0]]],
                non_trainable=[],
            ),
            optimizer_state=(),
            round_num=tf.constant(0, dtype=tf.int64),
            aggregator_state=(),
        )
        _, metrics = process.next(unit_weight_state, create_client_data())
        self.assertCountEqual(metrics.keys(), metric_names)
        # MSE is (y - 1 * x)^2 for each example, for a mean of
        # (4^2 + 4^2 + 5^2 + 4^2 + 3^2 + 6^2) / 6 = 59/3.
        self.assertAlmostEqual(metrics['loss'], 19.666666)
        self.assertAlmostEqual(metrics['num_examples_total'], 12.0)
        self.assertAlmostEqual(metrics['num_over'], 6.0)
Ejemplo n.º 3
0
 def server_init_tff():
     """Returns a `reconstruction_utils.ServerState` placed at `tff.SERVER`.

     The model, optimizer state and round number are the first three elements
     of the result of evaluating `server_init_tf` at the server. The aggregator
     state comes from `aggregation_process.initialize()`. The four pieces are
     zipped into one federated value.
     """
     server_values = tff.federated_eval(server_init_tf, tff.SERVER)
     aggregator_init = aggregation_process.initialize()
     # Indexing (rather than tuple-unpacking) is used on purpose: the result is
     # a federated value, which supports element selection via `[]`.
     unzipped_state = reconstruction_utils.ServerState(
         model=server_values[0],
         optimizer_state=server_values[1],
         round_num=server_values[2],
         aggregator_state=aggregator_init,
     )
     return tff.federated_zip(unzipped_state)
def build_federated_reconstruction_evaluation_process(
    model_fn: ModelFn,
    *,  # Callers pass below args by name.
    loss_fn: LossFn,
    metrics_fn: Optional[MetricsFn],
    reconstruction_optimizer_fn: OptimizerFn = functools.partial(
        tf.keras.optimizers.SGD, 0.1),
    dataset_split_fn: Optional[reconstruction_utils.DatasetSplitFn] = None
) -> tff.templates.IterativeProcess:
  """Builds an `IterativeProcess` for evaluation of `ReconstructionModel`s.

  The returned process wraps the `tff.Computation` returned by
  `build_federated_reconstruction_evaluation`, iteratively performing evaluation
  across clients for some number of rounds.

  Usage of returned process:
    eval_process = build_federated_reconstruction_evaluation_process(...)
    state = eval_process.initialize()
    state, metrics = eval_process.next(state, federated_data)

  The server state is passed through `next` unchanged. Only `state.model` is
  consumed by the wrapped evaluation computation, and `round_num` is not
  incremented between rounds.

  Args:
    model_fn: A no-arg function that returns a `ReconstructionModel`. This
      method must *not* capture Tensorflow tensors or variables and use them.
      Must be constructed entirely from scratch on each invocation, returning
      the same pre-constructed model each call will result in an error.
    loss_fn: A no-arg function returning a `tf.keras.losses.Loss` to use to
      evaluate the model. The loss will be applied to the model's outputs during
      the evaluation stage. The final loss metric is the example-weighted mean
      loss across batches (and across clients).
    metrics_fn: A no-arg function returning a list of `tf.keras.metrics.Metric`s
      to evaluate the model. The metrics will be applied to the model's outputs
      during the evaluation stage. Final metric values are the example-weighted
      mean of metric values across batches (and across clients). If None, no
      metrics are applied.
    reconstruction_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer` used to reconstruct the local variables
      with the global ones frozen.
    dataset_split_fn: A `reconstruction_utils.DatasetSplitFn` taking in a client
      dataset and round number (always 0 for evaluation) and producing two TF
      datasets. The first is iterated over during reconstruction, and the second
      is iterated over during evaluation. This can be used to preprocess
      datasets to e.g. iterate over them for multiple epochs or use disjoint
      data for reconstruction and evaluation. If None, split client data in half
      for each user, using one half for reconstruction and the other for
      evaluation. See `reconstruction_utils.build_dataset_split_fn` for options.

  Returns:
    `tff.templates.IterativeProcess` constructed from the `tff.Computation`
    returned by `build_federated_reconstruction_evaluation`.
    - `initialize` produces a `reconstruction_utils.ServerState` at
      `tff.SERVER`. It holds the global variables of a freshly built model,
      empty optimizer and aggregator state, and `round_num=0`.
    - `next` returns that (unchanged) state together with the evaluation
      metrics.
  """
  eval_comp = build_federated_reconstruction_evaluation(
      model_fn=model_fn,
      loss_fn=loss_fn,
      metrics_fn=metrics_fn,
      reconstruction_optimizer_fn=reconstruction_optimizer_fn,
      dataset_split_fn=dataset_split_fn)

  # The model type is the member type of `eval_comp`'s first (federated)
  # parameter, so the state accepted by `next` matches what `eval_comp` needs.
  server_state_type = tff.type_at_server(
      reconstruction_utils.ServerState(
          model=eval_comp.type_signature.parameter[0].member,
          # There is no server optimizer in eval, so the optimizer_state is
          # empty.
          optimizer_state=(),
          round_num=tf.TensorSpec((), dtype=tf.int64),
          # Aggregations are stateless for evaluation.
          aggregator_state=(),
      ))
  # Type of the federated client data: `eval_comp`'s second parameter.
  federated_data_type = eval_comp.type_signature.parameter[1]

  @tff.tf_computation()
  def create_initial_state():
    """Builds the initial server state from a freshly constructed model."""
    return reconstruction_utils.ServerState(
        model=reconstruction_utils.get_global_variables(model_fn()),
        optimizer_state=(),
        round_num=tf.constant(0, dtype=tf.int64),
        aggregator_state=(),
    )

  @tff.federated_computation()
  def initialize():
    return tff.federated_value(create_initial_state(), tff.SERVER)

  @tff.federated_computation(server_state_type, federated_data_type)
  def eval_next(state, data):
    # Evaluation only needs the model weights; the state is returned as-is.
    metrics = eval_comp(state.model, data)
    return state, metrics

  return tff.templates.IterativeProcess(initialize, eval_next)