def test_build_dataset_split_fn_recon_max_steps(self):
    # 3 batches.
    client_dataset = tf.data.Dataset.range(6).batch(2)

    split_dataset_fn = reconstruction_utils.build_dataset_split_fn(
        recon_epochs_max=2, recon_steps_max=4)
    # Round number shouldn't matter.
    recon_dataset, post_recon_dataset = split_dataset_fn(client_dataset, 3)

    recon_list = list(recon_dataset.as_numpy_iterator())
    post_recon_list = list(post_recon_dataset.as_numpy_iterator())

    self.assertAllEqual(recon_list, [[0, 1], [2, 3], [4, 5], [0, 1]])
    self.assertAllEqual(post_recon_list, [[0, 1], [2, 3], [4, 5]])

    # A steps cap larger than the actual number of steps has no effect.
    split_dataset_fn = reconstruction_utils.build_dataset_split_fn(
        recon_epochs_max=2, recon_steps_max=7)
    # Round number shouldn't matter.
    recon_dataset, post_recon_dataset = split_dataset_fn(client_dataset, 3)

    recon_list = list(recon_dataset.as_numpy_iterator())
    post_recon_list = list(post_recon_dataset.as_numpy_iterator())

    self.assertAllEqual(recon_list,
                        [[0, 1], [2, 3], [4, 5], [0, 1], [2, 3], [4, 5]])
    self.assertAllEqual(post_recon_list, [[0, 1], [2, 3], [4, 5]])
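
# A hedged sketch of the repeat-then-truncate semantics the test above
# exercises (inferred from the observed behavior, not from the library
# source): the reconstruction split behaves like
# `dataset.repeat(recon_epochs_max).take(recon_steps_max)`, where `take`
# is a no-op when the cap exceeds the number of available batches.
def _sketch_recon_split(dataset, recon_epochs_max, recon_steps_max):
  return dataset.repeat(recon_epochs_max).take(recon_steps_max)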

  def test_build_dataset_split_fn_recon_epochs_variable(self):
    # 3 batches.
    client_dataset = tf.data.Dataset.range(6).batch(2)

    split_dataset_fn = reconstruction_utils.build_dataset_split_fn(
        recon_epochs_max=8, recon_epochs_constant=False)

    round_num = tf.constant(1, dtype=tf.int64)
    recon_dataset, post_recon_dataset = split_dataset_fn(
        client_dataset, round_num)

    recon_list = list(recon_dataset.as_numpy_iterator())
    post_recon_list = list(post_recon_dataset.as_numpy_iterator())

    self.assertAllEqual(recon_list, [[0, 1], [2, 3], [4, 5]])
    self.assertAllEqual(post_recon_list, [[0, 1], [2, 3], [4, 5]])

    round_num = tf.constant(2, dtype=tf.int64)
    recon_dataset, post_recon_dataset = split_dataset_fn(
        client_dataset, round_num)

    recon_list = list(recon_dataset.as_numpy_iterator())
    post_recon_list = list(post_recon_dataset.as_numpy_iterator())

    self.assertAllEqual(recon_list,
                        [[0, 1], [2, 3], [4, 5], [0, 1], [2, 3], [4, 5]])
    self.assertAllEqual(post_recon_list, [[0, 1], [2, 3], [4, 5]])
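
# Hedged sketch of the variable-epochs behavior exercised above: with
# `recon_epochs_constant=False`, the number of reconstruction epochs appears
# to track the round number, capped at `recon_epochs_max`. This is an
# inference from this test, not a documented contract of the library.
def _expected_recon_epochs(round_num, recon_epochs_max):
  return min(round_num, recon_epochs_max)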

  def test_build_dataset_split_fn_split_dataset_one_batch(self):
    """Ensures a dataset with a single (undersized) batch splits correctly."""
    # 1 batch. Batch size can be larger than number of examples.
    client_dataset = tf.data.Dataset.range(1).batch(4)

    split_dataset_fn = reconstruction_utils.build_dataset_split_fn(
        split_dataset=True)

    # Round number doesn't matter.
    round_num = tf.constant(1, dtype=tf.int64)
    recon_dataset, post_recon_dataset = split_dataset_fn(
        client_dataset, round_num)

    recon_list = list(recon_dataset.as_numpy_iterator())
    post_recon_list = list(post_recon_dataset.as_numpy_iterator())

    self.assertAllEqual(recon_list, [[0]])
    self.assertAllEqual(post_recon_list, [])

    # Round number doesn't matter.
    round_num = tf.constant(2, dtype=tf.int64)
    recon_dataset, post_recon_dataset = split_dataset_fn(
        client_dataset, round_num)

    recon_list = list(recon_dataset.as_numpy_iterator())
    post_recon_list = list(post_recon_dataset.as_numpy_iterator())

    self.assertAllEqual(recon_list, [[0]])
    self.assertAllEqual(post_recon_list, [])

  def test_build_dataset_split_fn_split_dataset_even_batches(self):
    # 4 batches.
    client_dataset = tf.data.Dataset.range(8).batch(2)

    split_dataset_fn = reconstruction_utils.build_dataset_split_fn(
        split_dataset=True)

    # Round number doesn't matter.
    round_num = tf.constant(1, dtype=tf.int64)
    recon_dataset, post_recon_dataset = split_dataset_fn(
        client_dataset, round_num)

    recon_list = list(recon_dataset.as_numpy_iterator())
    post_recon_list = list(post_recon_dataset.as_numpy_iterator())

    self.assertAllEqual(recon_list, [[0, 1], [4, 5]])
    self.assertAllEqual(post_recon_list, [[2, 3], [6, 7]])

    # Round number doesn't matter.
    round_num = tf.constant(2, dtype=tf.int64)
    recon_dataset, post_recon_dataset = split_dataset_fn(
        client_dataset, round_num)

    recon_list = list(recon_dataset.as_numpy_iterator())
    post_recon_list = list(post_recon_dataset.as_numpy_iterator())

    self.assertAllEqual(recon_list, [[0, 1], [4, 5]])
    self.assertAllEqual(post_recon_list, [[2, 3], [6, 7]])
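
# A minimal sketch of the even/odd interleaving that `split_dataset=True`
# demonstrates above (assumed semantics for illustration; the library may
# implement the split differently): even-indexed batches are routed to
# reconstruction and odd-indexed batches to post-reconstruction.
def _sketch_even_odd_split(dataset):
  recon = dataset.enumerate().filter(
      lambda i, _: i % 2 == 0).map(lambda _, batch: batch)
  post_recon = dataset.enumerate().filter(
      lambda i, _: i % 2 == 1).map(lambda _, batch: batch)
  return recon, post_recon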

  def test_build_dataset_split_fn_split_dataset_zero_batches(self):
    """Ensures clients without any data don't fail."""
    # 0 batches.
    client_dataset = tf.data.Dataset.range(0).batch(2)

    split_dataset_fn = reconstruction_utils.build_dataset_split_fn(
        split_dataset=True)

    # Round number doesn't matter.
    round_num = tf.constant(1, dtype=tf.int64)
    recon_dataset, post_recon_dataset = split_dataset_fn(
        client_dataset, round_num)

    recon_list = list(recon_dataset.as_numpy_iterator())
    post_recon_list = list(post_recon_dataset.as_numpy_iterator())

    self.assertAllEqual(recon_list, [])
    self.assertAllEqual(post_recon_list, [])

    # Round number doesn't matter.
    round_num = tf.constant(2, dtype=tf.int64)
    recon_dataset, post_recon_dataset = split_dataset_fn(
        client_dataset, round_num)

    recon_list = list(recon_dataset.as_numpy_iterator())
    post_recon_list = list(post_recon_dataset.as_numpy_iterator())

    self.assertAllEqual(recon_list, [])
    self.assertAllEqual(post_recon_list, [])

    def test_federated_reconstruction_evaluation_process_no_recon(
            self, model_fn):
        def loss_fn():
            return tf.keras.losses.MeanSquaredError()

        def metrics_fn():
            return [NumExamplesCounter(), NumOverCounter(5.0)]

        dataset_split_fn = reconstruction_utils.build_dataset_split_fn(
            recon_epochs_max=0, post_recon_epochs=2)

        evaluator = evaluation_computation.build_federated_reconstruction_evaluation_process(
            model_fn,
            loss_fn=loss_fn,
            metrics_fn=metrics_fn,
            reconstruction_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1),
            dataset_split_fn=dataset_split_fn)
        self.assertEqual(
            str(evaluator.initialize.type_signature),
            '( -> <model=<trainable=<float32[1,1]>,non_trainable=<>>,'
            'optimizer_state=<>,round_num=int64,aggregator_state=<>>@SERVER)')
        self.assertEqual(
            str(evaluator.next.type_signature),
            '(<state=<model=<trainable=<float32[1,1]>,non_trainable=<>>,'
            'optimizer_state=<>,round_num=int64,aggregator_state=<>>@SERVER,'
            'data={<x=float32[?,1],y=float32[?,1]>*}@CLIENTS> -> '
            '<<model=<trainable=<float32[1,1]>,non_trainable=<>>,'
            'optimizer_state=<>,round_num=int64,aggregator_state=<>>@SERVER,'
            '<loss=float32,num_examples_total=float32,num_over=float32>@SERVER>)'
        )

        state = evaluator.initialize()
        state, metrics = evaluator.next(state, create_client_data())

        expected_keys = ['loss', 'num_examples_total', 'num_over']
        self.assertCountEqual(metrics.keys(), expected_keys)
        self.assertAlmostEqual(metrics['num_examples_total'], 12.0)
        self.assertAlmostEqual(metrics['num_over'], 6.0)

        # Without reconstruction and with an initialized model, we can expect an
        # exact value for loss.
        state = reconstruction_utils.ServerState(
            model=collections.OrderedDict([
                ('trainable', [[[1.0]]]),
                ('non_trainable', []),
            ]),
            optimizer_state=(),
            round_num=tf.constant(0, dtype=tf.int64),
            aggregator_state=(),
        )

        state, metrics = evaluator.next(state, create_client_data())

        expected_keys = ['loss', 'num_examples_total', 'num_over']
        self.assertCountEqual(metrics.keys(), expected_keys)
        # MSE is (y - 1 * x)^2 for each example, for a mean of
        # (4^2 + 4^2 + 5^2 + 4^2 + 3^2 + 6^2) / 6 = 59/3.
        self.assertAlmostEqual(metrics['loss'], 19.666666)
        self.assertAlmostEqual(metrics['num_examples_total'], 12.0)
        self.assertAlmostEqual(metrics['num_over'], 6.0)
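
# A quick numpy sanity check of the expected loss above (illustrative only,
# not part of the test): with the trainable weight at 1.0 and the local bias
# left at 0, predictions equal x, so the residuals are those in the comment.
import numpy as np
residuals = np.array([4.0, 4.0, 5.0, 4.0, 3.0, 6.0])
assert np.isclose(np.mean(residuals**2), 59.0 / 3.0)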

  def test_build_dataset_split_fn_post_recon_multiple_epochs_max_steps(self):
    # 3 batches.
    client_dataset = tf.data.Dataset.range(6).batch(2)

    split_dataset_fn = reconstruction_utils.build_dataset_split_fn(
        post_recon_epochs=2, post_recon_steps_max=4)

    # Round number doesn't matter.
    round_num = tf.constant(1, dtype=tf.int64)
    recon_dataset, post_recon_dataset = split_dataset_fn(
        client_dataset, round_num)

    recon_list = list(recon_dataset.as_numpy_iterator())
    post_recon_list = list(post_recon_dataset.as_numpy_iterator())

    self.assertAllEqual(recon_list, [[0, 1], [2, 3], [4, 5]])
    self.assertAllEqual(post_recon_list, [[0, 1], [2, 3], [4, 5], [0, 1]])

    # Round number doesn't matter.
    round_num = tf.constant(2, dtype=tf.int64)
    recon_dataset, post_recon_dataset = split_dataset_fn(
        client_dataset, round_num)

    recon_list = list(recon_dataset.as_numpy_iterator())
    post_recon_list = list(post_recon_dataset.as_numpy_iterator())

    self.assertAllEqual(recon_list, [[0, 1], [2, 3], [4, 5]])
    self.assertAllEqual(post_recon_list, [[0, 1], [2, 3], [4, 5], [0, 1]])

    def test_federated_reconstruction_metrics_none_loss_decreases(
            self, model_fn):
        def loss_fn():
            return tf.keras.losses.MeanSquaredError()

        dataset_split_fn = reconstruction_utils.build_dataset_split_fn(
            recon_epochs_max=3)

        evaluate = evaluation_computation.build_federated_reconstruction_evaluation(
            model_fn,
            loss_fn=loss_fn,
            metrics_fn=None,
            reconstruction_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.01),
            dataset_split_fn=dataset_split_fn)
        self.assertEqual(
            str(evaluate.type_signature),
            '(<server_model_weights=<trainable=<float32[1,1]>,'
            'non_trainable=<>>@SERVER,federated_dataset={<x=float32[?,1],'
            'y=float32[?,1]>*}@CLIENTS> -> <loss=float32>@SERVER)')

        result = evaluate(
            collections.OrderedDict([
                ('trainable', [[[1.0]]]),
                ('non_trainable', []),
            ]), create_client_data())

        expected_keys = ['loss']
        self.assertCountEqual(result.keys(), expected_keys)
        # Ensure loss decreases from reconstruction vs. initializing the bias to 0.
        # MSE is (y - 1 * x)^2 for each example, for a mean of
        # (4^2 + 4^2 + 5^2 + 4^2 + 3^2 + 6^2) / 6 = 59/3.
        self.assertLess(result['loss'], 19.666666)

    def test_federated_reconstruction_no_split_data(self, model_fn):
        def loss_fn():
            return tf.keras.losses.MeanSquaredError()

        def metrics_fn():
            return [NumExamplesCounter(), NumOverCounter(5.0)]

        dataset_split_fn = reconstruction_utils.build_dataset_split_fn()

        evaluate = evaluation_computation.build_federated_reconstruction_evaluation(
            model_fn,
            loss_fn=loss_fn,
            metrics_fn=metrics_fn,
            reconstruction_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1),
            dataset_split_fn=dataset_split_fn)
        self.assertEqual(
            str(evaluate.type_signature),
            '(<server_model_weights=<trainable=<float32[1,1]>,'
            'non_trainable=<>>@SERVER,federated_dataset={<x=float32[?,1],'
            'y=float32[?,1]>*}@CLIENTS> -> '
            '<loss=float32,num_examples_total=float32,num_over=float32>@SERVER)'
        )

        result = evaluate(
            collections.OrderedDict([
                ('trainable', [[[5.0]]]),
                ('non_trainable', []),
            ]), create_client_data())

        expected_keys = ['loss', 'num_examples_total', 'num_over']
        self.assertCountEqual(result.keys(), expected_keys)
        self.assertAlmostEqual(result['num_examples_total'], 6.0)
        self.assertAlmostEqual(result['num_over'], 3.0)

    def test_custom_model_eval_reconstruction_multiple_epochs(self):
        client_data = create_emnist_client_data()
        train_data = [client_data(), client_data()]

        def loss_fn():
            return tf.keras.losses.SparseCategoricalCrossentropy()

        def metrics_fn():
            return [
                NumExamplesCounter(),
                NumBatchesCounter(),
                tf.keras.metrics.SparseCategoricalAccuracy()
            ]

        dataset_split_fn = reconstruction_utils.build_dataset_split_fn(
            recon_epochs_max=3,
            recon_epochs_constant=False,
            post_recon_epochs=4,
            post_recon_steps_max=3)
        trainer = training_process.build_federated_reconstruction_process(
            MnistModel,
            loss_fn=loss_fn,
            metrics_fn=metrics_fn,
            client_optimizer_fn=functools.partial(tf.keras.optimizers.SGD,
                                                  0.001),
            reconstruction_optimizer_fn=functools.partial(
                tf.keras.optimizers.SGD, 0.001),
            evaluate_reconstruction=True,
            dataset_split_fn=dataset_split_fn)
        state = trainer.initialize()

        outputs = []
        states = []
        for _ in range(2):
            state, output = trainer.next(state, train_data)
            outputs.append(output)
            states.append(state)

        self.assertLess(outputs[1]['loss'], outputs[0]['loss'])
        self.assertNotAllClose(states[0].model.trainable,
                               states[1].model.trainable)

        # Expect 6 reconstruction examples, 10 training examples.
        self.assertEqual(outputs[0]['num_examples_total'], 16.0)
        # Expect 12 reconstruction examples, 10 training examples.
        self.assertEqual(outputs[1]['num_examples_total'], 22.0)

        # Expect 4 reconstruction batches and 6 training batches.
        self.assertEqual(outputs[0]['num_batches_total'], 10.0)
        # Expect 8 reconstruction batches and 6 training batches.
        self.assertEqual(outputs[1]['num_batches_total'], 14.0)

  def test_build_dataset_split_fn(self):
    # 3 batches.
    client_dataset = tf.data.Dataset.range(6).batch(2)

    split_dataset_fn = reconstruction_utils.build_dataset_split_fn(
        recon_epochs_max=2, post_recon_epochs=1)
    # Round number shouldn't matter.
    recon_dataset, post_recon_dataset = split_dataset_fn(client_dataset, 3)

    recon_list = list(recon_dataset.as_numpy_iterator())
    post_recon_list = list(post_recon_dataset.as_numpy_iterator())

    self.assertAllEqual(recon_list,
                        [[0, 1], [2, 3], [4, 5], [0, 1], [2, 3], [4, 5]])
    self.assertAllEqual(post_recon_list, [[0, 1], [2, 3], [4, 5]])

    def test_custom_model_eval_reconstruction_split_multiple_epochs(self):
        client_data = create_emnist_client_data()
        # 3 batches per user, each with one example. Since each user's data is
        # split, each user has 2 unique recon examples and 1 unique post-recon
        # example (even-indexed batches are allocated to recon during splitting).
        train_data = [client_data(batch_size=1), client_data(batch_size=1)]

        def loss_fn():
            return tf.keras.losses.SparseCategoricalCrossentropy()

        def metrics_fn():
            return [
                NumExamplesCounter(),
                NumBatchesCounter(),
                tf.keras.metrics.SparseCategoricalAccuracy()
            ]

        dataset_split_fn = reconstruction_utils.build_dataset_split_fn(
            recon_epochs_max=3, split_dataset=True, post_recon_epochs=5)
        trainer = training_process.build_federated_reconstruction_process(
            MnistModel,
            loss_fn=loss_fn,
            metrics_fn=metrics_fn,
            client_optimizer_fn=functools.partial(tf.keras.optimizers.SGD,
                                                  0.001),
            evaluate_reconstruction=True,
            dataset_split_fn=dataset_split_fn)
        state = trainer.initialize()

        outputs = []
        states = []
        for _ in range(2):
            state, output = trainer.next(state, train_data)
            outputs.append(output)
            states.append(state)

        self.assertLess(outputs[1]['loss'], outputs[0]['loss'])
        self.assertNotAllClose(states[0].model.trainable,
                               states[1].model.trainable)

        # Expect 12 reconstruction examples, 10 training examples.
        self.assertEqual(outputs[0]['num_examples_total'], 22.0)
        self.assertEqual(outputs[1]['num_examples_total'], 22.0)

        # Expect 12 reconstruction batches and 10 training batches.
        self.assertEqual(outputs[0]['num_batches_total'], 22.0)
        self.assertEqual(outputs[1]['num_batches_total'], 22.0)

  def test_personal_matrix_factorization_trains_reconstruction_model(self):
    train_data = [
        self.train_users.flatten().tolist(),
        self.train_items.flatten().tolist(),
        self.train_preferences.flatten().tolist()
    ]
    train_tf_dataset = tf.data.Dataset.from_tensor_slices(
        list(zip(*train_data)))

    def batch_map_fn(example_batch):
      return collections.OrderedDict(
          x=tf.cast(example_batch[:, 0:1], tf.int64), y=example_batch[:, 1:2])

    train_tf_dataset = train_tf_dataset.batch(1).map(batch_map_fn).repeat(5)
    train_tf_datasets = [train_tf_dataset] * 2

    num_users = 1
    num_items = 8
    num_latent_factors = 10
    personal_model = True
    add_biases = False
    l2_regularization = 0.0

    tff_model_fn = models.build_reconstruction_model(
        functools.partial(
            models.get_matrix_factorization_model,
            num_users,
            num_items,
            num_latent_factors,
            personal_model=personal_model,
            add_biases=add_biases,
            l2_regularization=l2_regularization))

    # Also test `models.get_loss_fn` and `models.get_metrics_fn`.
    trainer = training_process.build_federated_reconstruction_process(
        tff_model_fn,
        loss_fn=models.get_loss_fn(),
        metrics_fn=models.get_metrics_fn(),
        client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1e-2),
        reconstruction_optimizer_fn=(
            lambda: tf.keras.optimizers.SGD(learning_rate=1e-3)),
        dataset_split_fn=reconstruction_utils.build_dataset_split_fn(
            recon_epochs_max=10))

    state = trainer.initialize()
    trainer.next(state, train_tf_datasets)

    def test_federated_reconstruction_evaluation_process(self, model_fn):
        def loss_fn():
            return tf.keras.losses.MeanSquaredError()

        def metrics_fn():
            return [NumExamplesCounter(), NumOverCounter(5.0)]

        dataset_split_fn = reconstruction_utils.build_dataset_split_fn(
            recon_epochs_max=2,
            post_recon_epochs=10,
            post_recon_steps_max=7,
            split_dataset=True)

        evaluator = evaluation_computation.build_federated_reconstruction_evaluation_process(
            model_fn,
            loss_fn=loss_fn,
            metrics_fn=metrics_fn,
            reconstruction_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1),
            dataset_split_fn=dataset_split_fn)
        self.assertEqual(
            str(evaluator.initialize.type_signature),
            '( -> <model=<trainable=<float32[1,1]>,non_trainable=<>>,'
            'optimizer_state=<>,round_num=int64,aggregator_state=<>>@SERVER)')
        self.assertEqual(
            str(evaluator.next.type_signature),
            '(<state=<model=<trainable=<float32[1,1]>,non_trainable=<>>,'
            'optimizer_state=<>,round_num=int64,aggregator_state=<>>@SERVER,'
            'data={<x=float32[?,1],y=float32[?,1]>*}@CLIENTS> -> '
            '<<model=<trainable=<float32[1,1]>,non_trainable=<>>,'
            'optimizer_state=<>,round_num=int64,aggregator_state=<>>@SERVER,'
            '<loss=float32,num_examples_total=float32,num_over=float32>@SERVER>)'
        )

        state = evaluator.initialize()
        state, metrics = evaluator.next(state, create_client_data())

        expected_keys = ['loss', 'num_examples_total', 'num_over']
        self.assertCountEqual(metrics.keys(), expected_keys)
        self.assertAlmostEqual(metrics['num_examples_total'], 14.0)
        self.assertAlmostEqual(metrics['num_over'], 7.0)

    def test_federated_reconstruction_recon_lr_0(self, model_fn):
        def loss_fn():
            return tf.keras.losses.MeanSquaredError()

        def metrics_fn():
            return [NumExamplesCounter(), NumOverCounter(5.0)]

        dataset_split_fn = reconstruction_utils.build_dataset_split_fn()

        evaluate = evaluation_computation.build_federated_reconstruction_evaluation(
            model_fn,
            loss_fn=loss_fn,
            metrics_fn=metrics_fn,
            # Set recon optimizer LR to 0 so reconstruction has no effect.
            reconstruction_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.0),
            dataset_split_fn=dataset_split_fn)
        self.assertEqual(
            str(evaluate.type_signature),
            '(<server_model_weights=<trainable=<float32[1,1]>,'
            'non_trainable=<>>@SERVER,federated_dataset={<x=float32[?,1],'
            'y=float32[?,1]>*}@CLIENTS> -> '
            '<loss=float32,num_examples_total=float32,num_over=float32>@SERVER)'
        )

        result = evaluate(
            collections.OrderedDict([
                ('trainable', [[[1.0]]]),
                ('non_trainable', []),
            ]), create_client_data())

        expected_keys = ['loss', 'num_examples_total', 'num_over']
        self.assertCountEqual(result.keys(), expected_keys)
        # Now have an expectation for loss since the local bias is initialized at 0
        # and not reconstructed. MSE is (y - 1 * x)^2 for each example, for a mean
        # of (4^2 + 4^2 + 5^2 + 4^2 + 3^2 + 6^2) / 6 = 59/3.
        self.assertAlmostEqual(result['loss'], 19.666666)
        self.assertAlmostEqual(result['num_examples_total'], 6.0)
        self.assertAlmostEqual(result['num_over'], 3.0)


def build_federated_reconstruction_evaluation(
    model_fn: ModelFn,
    *,  # Callers pass below args by name.
    loss_fn: LossFn,
    metrics_fn: Optional[MetricsFn],
    reconstruction_optimizer_fn: OptimizerFn = functools.partial(
        tf.keras.optimizers.SGD, 0.1),
    dataset_split_fn: Optional[reconstruction_utils.DatasetSplitFn] = None
) -> tff.Computation:
  """Builds a `tff.Computation` for evaluation of a `ReconstructionModel`.

  The returned computation proceeds in two stages: (1) reconstruction and (2)
  evaluation. During the reconstruction stage, local variables are reconstructed
  by freezing the global variables and training with the optimizer produced by
  `reconstruction_optimizer_fn`. During the evaluation stage, the reconstructed
  local variables and the global variables are evaluated using the provided
  `loss_fn` and `metrics_fn`.

  Usage of returned computation:
    eval_comp = build_federated_reconstruction_evaluation(...)
    metrics = eval_comp(reconstruction_utils.get_global_variables(model),
                        federated_data)

  Args:
    model_fn: A no-arg function that returns a `ReconstructionModel`. This
      method must *not* capture TensorFlow tensors or variables and use them.
      The model must be constructed entirely from scratch on each invocation;
      returning the same pre-constructed model on each call will result in an
      error.
    loss_fn: A no-arg function returning a `tf.keras.losses.Loss` to use to
      evaluate the model. The loss will be applied to the model's outputs during
      the evaluation stage. The final loss metric is the example-weighted mean
      loss across batches (and across clients).
    metrics_fn: A no-arg function returning a list of `tf.keras.metrics.Metric`s
      to evaluate the model. The metrics will be applied to the model's outputs
      during the evaluation stage. Final metric values are the example-weighted
      mean of metric values across batches (and across clients). If None, no
      metrics are applied.
    reconstruction_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer` used to reconstruct the local variables
      with the global ones frozen.
    dataset_split_fn: A `reconstruction_utils.DatasetSplitFn` taking in a client
      dataset and round number (always 0 for evaluation) and producing two TF
      datasets. The first is iterated over during reconstruction, and the second
      is iterated over during evaluation. This can be used to preprocess
      datasets to e.g. iterate over them for multiple epochs or use disjoint
      data for reconstruction and evaluation. If None, splits client data in
      half for each user, using one half for reconstruction and the other for
      evaluation. See `reconstruction_utils.build_dataset_split_fn` for options.

  Raises:
    ValueError: if both `loss_fn` and `metrics_fn` are None.

  Returns:
    A `tff.Computation` that accepts model parameters and federated data and
    returns example-weighted evaluation loss and metrics.
  """
  # Construct the model first just to obtain the metadata and define all the
  # types needed to define the computations that follow.
  with tf.Graph().as_default():
    model = model_fn()
    global_weights = reconstruction_utils.get_global_variables(model)
    model_weights_type = tff.framework.type_from_tensors(global_weights)
    batch_type = tff.to_type(model.input_spec)
    metrics = [keras_utils.MeanLossMetric(loss_fn())]
    if metrics_fn is not None:
      metrics.extend(metrics_fn())
    if not metrics:
      raise ValueError(
          'One or both of metrics_fn and loss_fn should be provided.')
    federated_output_computation = (
        keras_utils.federated_output_computation_from_metrics(metrics))
    # Remove unneeded variables to avoid polluting namespace.
    del model
    del global_weights
    del metrics

  if dataset_split_fn is None:
    dataset_split_fn = reconstruction_utils.build_dataset_split_fn(
        split_dataset=True)

  @tff.tf_computation(model_weights_type, tff.SequenceType(batch_type))
  def client_computation(incoming_model_weights, client_dataset):
    """Reconstructs and evaluates with `incoming_model_weights`."""
    client_model = model_fn()
    client_global_weights = reconstruction_utils.get_global_variables(
        client_model)
    client_local_weights = reconstruction_utils.get_local_variables(
        client_model)
    metrics = [keras_utils.MeanLossMetric(loss_fn())]
    if metrics_fn is not None:
      metrics.extend(metrics_fn())
    batch_loss_fn = loss_fn()
    reconstruction_optimizer = reconstruction_optimizer_fn()

    @tf.function
    def reconstruction_reduce_fn(num_examples_sum, batch):
      """Runs reconstruction training on local client batch."""
      with tf.GradientTape() as tape:
        output = client_model.forward_pass(batch, training=True)
        batch_loss = batch_loss_fn(
            y_true=output.labels, y_pred=output.predictions)

      gradients = tape.gradient(batch_loss, client_local_weights.trainable)
      reconstruction_optimizer.apply_gradients(
          zip(gradients, client_local_weights.trainable))

      return num_examples_sum + output.num_examples

    @tf.function
    def evaluation_reduce_fn(num_examples_sum, batch):
      """Runs evaluation on client batch without training."""
      output = client_model.forward_pass(batch, training=False)
      # Update each metric.
      for metric in metrics:
        metric.update_state(y_true=output.labels, y_pred=output.predictions)
      return num_examples_sum + output.num_examples

    @tf.function
    def tf_client_computation(incoming_model_weights, client_dataset):
      """Reconstructs and evaluates with `incoming_model_weights`."""
      # Pass in fixed 0 round number during evaluation, since global variables
      # aren't being iteratively updated as in training.
      recon_dataset, eval_dataset = dataset_split_fn(
          client_dataset, tf.constant(0, dtype=tf.int64))

      # Assign incoming global weights to `client_model` before reconstruction.
      tf.nest.map_structure(lambda v, t: v.assign(t), client_global_weights,
                            incoming_model_weights)

      recon_dataset.reduce(tf.constant(0), reconstruction_reduce_fn)
      eval_dataset.reduce(tf.constant(0), evaluation_reduce_fn)

      eval_local_outputs = keras_utils.read_metric_variables(metrics)
      return eval_local_outputs

    return tf_client_computation(incoming_model_weights, client_dataset)

  @tff.federated_computation(
      tff.type_at_server(model_weights_type),
      tff.type_at_clients(tff.SequenceType(batch_type)))
  def server_eval(server_model_weights, federated_dataset):
    client_outputs = tff.federated_map(
        client_computation,
        [tff.federated_broadcast(server_model_weights), federated_dataset])
    return federated_output_computation(client_outputs)

  return server_eval
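
# A hedged usage sketch for the evaluation computation built above.
# `my_model_fn`, `model_weights`, and `federated_data` are hypothetical
# placeholders for user-supplied objects; the keyword arguments mirror the
# signature documented in the docstring.
eval_comp = build_federated_reconstruction_evaluation(
    my_model_fn,
    loss_fn=lambda: tf.keras.losses.MeanSquaredError(),
    metrics_fn=None,
    reconstruction_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1),
    dataset_split_fn=None)  # None: even/odd split of each client's data.
metrics = eval_comp(model_weights, federated_data)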