def test_build_dataset_split_fn_recon_max_steps(self):
        """Checks that `recon_steps_max` truncates multi-epoch reconstruction."""
        # Six elements batched in pairs -> three batches.
        client_dataset = tf.data.Dataset.range(6).batch(2)

        def materialize(steps_max):
            """Splits `client_dataset` and returns both outputs as lists."""
            split_fn = reconstruction_utils.build_dataset_split_fn(
                recon_epochs=2, recon_steps_max=steps_max)
            recon_ds, post_recon_ds = split_fn(client_dataset)
            return (list(recon_ds.as_numpy_iterator()),
                    list(post_recon_ds.as_numpy_iterator()))

        # Two epochs would yield six batches; a cap of four keeps the first four.
        recon_batches, post_recon_batches = materialize(4)
        self.assertAllEqual(recon_batches, [[0, 1], [2, 3], [4, 5], [0, 1]])
        self.assertAllEqual(post_recon_batches, [[0, 1], [2, 3], [4, 5]])

        # A cap above the number of available steps leaves both epochs intact.
        recon_batches, post_recon_batches = materialize(7)
        self.assertAllEqual(recon_batches,
                            [[0, 1], [2, 3], [4, 5], [0, 1], [2, 3], [4, 5]])
        self.assertAllEqual(post_recon_batches, [[0, 1], [2, 3], [4, 5]])
# Example #2
# 0
  def test_federated_reconstruction_metrics_none_loss_decreases(self, model_fn):
    # With no `metrics_fn`, the `eval` output carries only the loss.

    def make_mse_loss():
      return tf.keras.losses.MeanSquaredError()

    def make_recon_optimizer():
      return tf.keras.optimizers.SGD(0.01)

    three_epoch_split_fn = reconstruction_utils.build_dataset_split_fn(
        recon_epochs=3)

    evaluate = evaluation_computation.build_federated_evaluation(
        model_fn,
        loss_fn=make_mse_loss,
        reconstruction_optimizer_fn=make_recon_optimizer,
        dataset_split_fn=three_epoch_split_fn)

    expected_signature = (
        '(<server_model_weights=<trainable=<float32[1,1]>,'
        'non_trainable=<>>@SERVER,federated_dataset={<x=float32[?,1],'
        'y=float32[?,1]>*}@CLIENTS> -> <broadcast=<>,eval='
        '<loss=float32>>@SERVER)')
    self.assertEqual(str(evaluate.type_signature), expected_signature)

    server_weights = collections.OrderedDict(
        trainable=[[[1.0]]], non_trainable=[])
    eval_metrics = evaluate(server_weights, create_client_data())['eval']

    # Without reconstruction (bias left at 0) the MSE of (y - 1 * x)^2 would be
    # (4^2 + 4^2 + 5^2 + 4^2 + 3^2 + 6^2) / 6 = 59/3; reconstruction must beat it.
    self.assertLess(eval_metrics['loss'], 19.666666)
    def test_build_dataset_split_fn_split_dataset_even_batches(self):
        # Eight elements batched in pairs -> four batches.
        client_dataset = tf.data.Dataset.range(8).batch(2)

        recon_ds, post_recon_ds = reconstruction_utils.build_dataset_split_fn(
            split_dataset=True)(client_dataset)
        recon_batches, post_recon_batches = (
            list(ds.as_numpy_iterator()) for ds in (recon_ds, post_recon_ds))

        # Batches alternate: even-indexed ones feed reconstruction, odd-indexed
        # ones feed post-reconstruction.
        self.assertAllEqual(recon_batches, [[0, 1], [4, 5]])
        self.assertAllEqual(post_recon_batches, [[2, 3], [6, 7]])
    def test_build_dataset_split_fn_post_recon_multiple_epochs_max_steps(self):
        # Six elements batched in pairs -> three batches.
        client_dataset = tf.data.Dataset.range(6).batch(2)

        recon_ds, post_recon_ds = reconstruction_utils.build_dataset_split_fn(
            post_recon_epochs=2, post_recon_steps_max=4)(client_dataset)
        recon_batches, post_recon_batches = (
            list(ds.as_numpy_iterator()) for ds in (recon_ds, post_recon_ds))

        # Reconstruction sees one full pass; post-reconstruction repeats for two
        # epochs but is truncated to four steps.
        self.assertAllEqual(recon_batches, [[0, 1], [2, 3], [4, 5]])
        self.assertAllEqual(post_recon_batches, [[0, 1], [2, 3], [4, 5], [0, 1]])
    def test_build_dataset_split_fn_split_dataset_one_batch(self):
        """Ensures a client with a single batch is split without failing.

        With `split_dataset=True`, the lone batch goes to the reconstruction
        dataset and the post-reconstruction dataset is left empty.
        """
        # 1 batch. Batch size can be larger than number of examples.
        client_dataset = tf.data.Dataset.range(1).batch(4)

        split_dataset_fn = reconstruction_utils.build_dataset_split_fn(
            split_dataset=True)
        recon_dataset, post_recon_dataset = split_dataset_fn(client_dataset)

        recon_list = list(recon_dataset.as_numpy_iterator())
        post_recon_list = list(post_recon_dataset.as_numpy_iterator())

        self.assertAllEqual(recon_list, [[0]])
        self.assertAllEqual(post_recon_list, [])
    def test_build_dataset_split_fn_split_dataset_zero_batches(self):
        """Ensures clients without any data don't fail."""
        # An empty range batches into zero batches.
        empty_dataset = tf.data.Dataset.range(0).batch(2)

        recon_ds, post_recon_ds = reconstruction_utils.build_dataset_split_fn(
            split_dataset=True)(empty_dataset)
        recon_batches, post_recon_batches = (
            list(ds.as_numpy_iterator()) for ds in (recon_ds, post_recon_ds))

        # Both outputs must simply be empty.
        self.assertAllEqual(recon_batches, [])
        self.assertAllEqual(post_recon_batches, [])
    def test_custom_model_multiple_epochs(self, optimizer_fn):
        make_client_data = create_emnist_client_data()
        train_data = [make_client_data() for _ in range(2)]

        def loss_fn():
            return tf.keras.losses.SparseCategoricalCrossentropy()

        def metrics_fn():
            return [
                counters.NumExamplesCounter(),
                counters.NumBatchesCounter(),
                tf.keras.metrics.SparseCategoricalAccuracy(),
            ]

        multi_epoch_split_fn = reconstruction_utils.build_dataset_split_fn(
            recon_epochs=3, post_recon_epochs=4, post_recon_steps_max=3)
        trainer = training_process.build_training_process(
            MnistModel,
            loss_fn=loss_fn,
            metrics_fn=metrics_fn,
            client_optimizer_fn=optimizer_fn(0.001),
            reconstruction_optimizer_fn=optimizer_fn(0.001),
            dataset_split_fn=multi_epoch_split_fn)

        # Run two rounds, remembering the state and output of each.
        state = trainer.initialize()
        history = []
        for _ in range(2):
            state, output = trainer.next(state, train_data)
            history.append((state, output))
        (first_state, first_output), (second_state, second_output) = history

        first_train = first_output['train']
        second_train = second_output['train']

        # Training must make progress and actually move the global weights.
        self.assertLess(second_train['loss'], first_train['loss'])
        self.assertNotAllClose(first_state.model.trainable,
                               second_state.model.trainable)

        # Counts are identical across rounds; 6 batches is consistent with
        # post_recon_steps_max=3 applied to each of the two clients.
        self.assertEqual(first_train['num_examples'], 10.0)
        self.assertEqual(second_train['num_examples'], 10.0)
        self.assertEqual(first_train['num_batches'], 6.0)
        self.assertEqual(second_train['num_batches'], 6.0)
# Example #8
# 0
  def test_federated_reconstruction_recon_lr_0(self, model_fn):
    # A zero learning rate turns reconstruction into a no-op, so the loss can
    # be computed by hand.

    def make_mse_loss():
      return tf.keras.losses.MeanSquaredError()

    def make_metrics():
      return [NumExamplesCounter(), NumOverCounter(5.0)]

    def make_frozen_recon_optimizer():
      # Set recon optimizer LR to 0 so reconstruction has no effect.
      return tf.keras.optimizers.SGD(0.0)

    default_split_fn = reconstruction_utils.build_dataset_split_fn()

    evaluate = evaluation_computation.build_federated_evaluation(
        model_fn,
        loss_fn=make_mse_loss,
        metrics_fn=make_metrics,
        reconstruction_optimizer_fn=make_frozen_recon_optimizer,
        dataset_split_fn=default_split_fn)

    expected_signature = (
        '(<server_model_weights=<trainable=<float32[1,1]>,'
        'non_trainable=<>>@SERVER,federated_dataset={<x=float32[?,1],'
        'y=float32[?,1]>*}@CLIENTS> -> <broadcast=<>,eval='
        '<loss=float32,num_examples_total=float32,num_over=float32>>@SERVER)')
    self.assertEqual(str(evaluate.type_signature), expected_signature)

    server_weights = collections.OrderedDict(
        trainable=[[[1.0]]], non_trainable=[])
    eval_metrics = evaluate(server_weights, create_client_data())['eval']

    # The local bias stays at its initial 0, so MSE is (y - 1 * x)^2 per
    # example: (4^2 + 4^2 + 5^2 + 4^2 + 3^2 + 6^2) / 6 = 59/3.
    self.assertAlmostEqual(eval_metrics['loss'], 19.666666)
    self.assertAlmostEqual(eval_metrics['num_examples_total'], 6.0)
    self.assertAlmostEqual(eval_metrics['num_over'], 3.0)
# Example #9
# 0
  def test_federated_reconstruction_split_data_multiple_epochs(self, model_fn):

    def make_mse_loss():
      return tf.keras.losses.MeanSquaredError()

    def make_metrics():
      return [NumExamplesCounter(), NumOverCounter(5.0)]

    def make_recon_optimizer():
      return tf.keras.optimizers.SGD(0.1)

    # Disjoint recon/eval data, repeated over several epochs with a step cap.
    multi_epoch_split_fn = reconstruction_utils.build_dataset_split_fn(
        recon_epochs=2,
        post_recon_epochs=10,
        post_recon_steps_max=7,
        split_dataset=True)

    evaluate = evaluation_computation.build_federated_evaluation(
        model_fn,
        loss_fn=make_mse_loss,
        metrics_fn=make_metrics,
        reconstruction_optimizer_fn=make_recon_optimizer,
        dataset_split_fn=multi_epoch_split_fn)

    expected_signature = (
        '(<server_model_weights=<trainable=<float32[1,1]>,'
        'non_trainable=<>>@SERVER,federated_dataset={<x=float32[?,1],'
        'y=float32[?,1]>*}@CLIENTS> -> <broadcast=<>,eval='
        '<loss=float32,num_examples_total=float32,num_over=float32>>@SERVER)')
    self.assertEqual(str(evaluate.type_signature), expected_signature)

    server_weights = collections.OrderedDict(
        trainable=[[[5.0]]], non_trainable=[])
    eval_metrics = evaluate(server_weights, create_client_data())['eval']

    self.assertAlmostEqual(eval_metrics['num_examples_total'], 14.0)
    self.assertAlmostEqual(eval_metrics['num_over'], 7.0)
def build_federated_evaluation(
    model_fn: training_process.ModelFn,
    *,  # Callers pass below args by name.
    loss_fn: training_process.LossFn,
    metrics_fn: Optional[training_process.MetricsFn] = None,
    reconstruction_optimizer_fn: training_process.OptimizerFn = functools.
    partial(tf.keras.optimizers.SGD, 0.1),
    dataset_split_fn: Optional[reconstruction_utils.DatasetSplitFn] = None,
    broadcast_process: Optional[measured_process_lib.MeasuredProcess] = None,
) -> computation_base.Computation:
    """Builds a `tff.Computation` for evaluating a reconstruction `Model`.

    The returned computation proceeds in two stages: (1) reconstruction and (2)
    evaluation. During the reconstruction stage, local variables are
    reconstructed by freezing global variables and training using
    `reconstruction_optimizer_fn`. During the evaluation stage, the
    reconstructed local variables and global variables are evaluated using the
    provided `loss_fn` and `metrics_fn`.

    Usage of returned computation:
      eval_comp = build_federated_evaluation(...)
      metrics = eval_comp(
          tff.learning.reconstruction.get_global_variables(model),
          federated_data)

    Args:
      model_fn: A no-arg function that returns a
        `tff.learning.reconstruction.Model`. This method must *not* capture
        Tensorflow tensors or variables and use them. Must be constructed
        entirely from scratch on each invocation; returning the same
        pre-constructed model each call will result in an error.
      loss_fn: A no-arg function returning a `tf.keras.losses.Loss`. It is the
        training objective during reconstruction and, wrapped in a mean-loss
        metric, the first reported evaluation metric (`loss`). The final loss
        metric is the example-weighted mean loss across batches (and across
        clients).
      metrics_fn: A no-arg function returning a list of
        `tf.keras.metrics.Metric`s to evaluate the model. The metrics are
        applied to the model's outputs during the evaluation stage and are
        reported after `loss`. Final metric values are the example-weighted
        mean of metric values across batches (and across clients). If None,
        only the loss is reported.
      reconstruction_optimizer_fn: A no-arg function that returns a
        `tf.keras.optimizers.Optimizer` used to reconstruct the local variables
        with the global ones frozen.
      dataset_split_fn: A `tff.learning.reconstruction.DatasetSplitFn` taking
        in a single TF dataset and producing two TF datasets. The first is
        iterated over during reconstruction, and the second is iterated over
        during evaluation. This can be used to preprocess datasets to e.g.
        iterate over them for multiple epochs or use disjoint data for
        reconstruction and evaluation. If None, uses
        `build_dataset_split_fn(split_dataset=True)`, which divides each
        client's batches between reconstruction and evaluation. See
        `tff.learning.reconstruction.build_dataset_split_fn` for options.
      broadcast_process: A `tff.templates.MeasuredProcess` that broadcasts the
        model weights on the server to the clients. Its `next` must have the
        signature `(<state@S, input@S> -> <state@S, result@C,
        measurements@S>)` and its state must be empty. If None, a stateless
        broadcaster built by `optimizer_utils.build_stateless_broadcaster` is
        used.

    Raises:
      TypeError: if `broadcast_process` does not have the expected signature or
        has non-empty state.

    Returns:
      A `tff.Computation` taking `(server_model_weights@SERVER,
      federated_dataset@CLIENTS)` and returning, at the server, an ordered dict
      with `broadcast` (the broadcast process's measurements) and `eval` (the
      example-weighted evaluation loss and metrics aggregated across clients).
    """
    # Construct the model first just to obtain the metadata and define all the
    # types needed to define the computations that follow.
    with tf.Graph().as_default():
        model = model_fn()
        global_weights = reconstruction_utils.get_global_variables(model)
        model_weights_type = type_conversions.type_from_tensors(global_weights)
        batch_type = computation_types.to_type(model.input_spec)
        # The loss metric comes first, so `loss` precedes any user metrics in
        # the output; this order must match `client_computation` below.
        metrics = [keras_utils.MeanLossMetric(loss_fn())]
        if metrics_fn is not None:
            metrics.extend(metrics_fn())
        federated_output_computation = (
            keras_utils.federated_output_computation_from_metrics(metrics))
        # Remove unneeded variables to avoid polluting namespace.
        del model
        del global_weights
        del metrics

    # Default: divide each client's batches between reconstruction and
    # evaluation so the two stages use disjoint data.
    if dataset_split_fn is None:
        dataset_split_fn = reconstruction_utils.build_dataset_split_fn(
            split_dataset=True)

    if broadcast_process is None:
        broadcast_process = optimizer_utils.build_stateless_broadcaster(
            model_weights_type=model_weights_type)
    if not optimizer_utils.is_valid_broadcast_process(broadcast_process):
        raise TypeError(
            'broadcast_process type signature does not conform to expected '
            'signature (<state@S, input@S> -> <state@S, result@C, measurements@S>).'
            ' Got: {t}'.format(t=broadcast_process.next.type_signature))
    # `server_eval` calls `broadcast_process.initialize()` afresh on every
    # invocation and discards the updated state, so only stateless
    # broadcasters are meaningful here.
    if iterative_process.is_stateful(broadcast_process):
        raise TypeError(
            f'Eval broadcast_process must be stateless (have an empty '
            'state), has state '
            f'{broadcast_process.initialize.type_signature.result!r}')

    @tensorflow_computation.tf_computation(
        model_weights_type, computation_types.SequenceType(batch_type))
    def client_computation(incoming_model_weights: computation_types.Type,
                           client_dataset: computation_types.SequenceType):
        """Reconstructs and evaluates with `incoming_model_weights`.

        Builds a fresh model, metrics, loss and reconstruction optimizer for
        this client, then returns the values read from the metric variables.
        """
        client_model = model_fn()
        client_global_weights = reconstruction_utils.get_global_variables(
            client_model)
        client_local_weights = reconstruction_utils.get_local_variables(
            client_model)
        # Same metric order as the metadata pass above (loss first).
        metrics = [keras_utils.MeanLossMetric(loss_fn())]
        if metrics_fn is not None:
            metrics.extend(metrics_fn())
        client_loss = loss_fn()
        reconstruction_optimizer = reconstruction_optimizer_fn()

        @tf.function
        def reconstruction_reduce_fn(num_examples_sum, batch):
            """Runs reconstruction training on local client batch.

            Gradients are taken only with respect to the local trainable
            variables, so the global weights stay frozen during this stage.
            Returns the running example count.
            """
            with tf.GradientTape() as tape:
                output = client_model.forward_pass(batch, training=True)
                batch_loss = client_loss(y_true=output.labels,
                                         y_pred=output.predictions)

            gradients = tape.gradient(batch_loss,
                                      client_local_weights.trainable)
            reconstruction_optimizer.apply_gradients(
                zip(gradients, client_local_weights.trainable))
            return num_examples_sum + output.num_examples

        @tf.function
        def evaluation_reduce_fn(num_examples_sum, batch):
            """Runs evaluation on client batch without training.

            Updates every metric (including the loss metric) and returns the
            running example count.
            """
            output = client_model.forward_pass(batch, training=False)
            # Update each metric.
            for metric in metrics:
                metric.update_state(y_true=output.labels,
                                    y_pred=output.predictions)
            return num_examples_sum + output.num_examples

        @tf.function
        def tf_client_computation(incoming_model_weights, client_dataset):
            """Reconstructs and evaluates with `incoming_model_weights`."""
            recon_dataset, eval_dataset = dataset_split_fn(client_dataset)

            # Assign incoming global weights to `client_model` before reconstruction.
            tf.nest.map_structure(lambda v, t: v.assign(t),
                                  client_global_weights,
                                  incoming_model_weights)

            # The example counts returned by these reductions are not used;
            # results are carried by the metric variables read below.
            recon_dataset.reduce(tf.constant(0), reconstruction_reduce_fn)
            eval_dataset.reduce(tf.constant(0), evaluation_reduce_fn)

            eval_local_outputs = keras_utils.read_metric_variables(metrics)
            return eval_local_outputs

        return tf_client_computation(incoming_model_weights, client_dataset)

    @federated_computation.federated_computation(
        computation_types.at_server(model_weights_type),
        computation_types.at_clients(
            computation_types.SequenceType(batch_type)))
    def server_eval(server_model_weights: computation_types.FederatedType,
                    federated_dataset: computation_types.FederatedType):
        """Broadcasts weights, evaluates on clients and aggregates metrics."""
        # A fresh (empty) broadcast state is used each call; statelessness was
        # enforced above.
        broadcast_output = broadcast_process.next(
            broadcast_process.initialize(), server_model_weights)
        client_outputs = intrinsics.federated_map(
            client_computation, [broadcast_output.result, federated_dataset])
        aggregated_client_outputs = federated_output_computation(
            client_outputs)
        measurements = intrinsics.federated_zip(
            collections.OrderedDict(broadcast=broadcast_output.measurements,
                                    eval=aggregated_client_outputs))
        return measurements

    return server_eval
# Example #11
# 0
def build_training_process(
    model_fn: ModelFn,
    *,  # Callers pass below args by name.
    loss_fn: LossFn,
    metrics_fn: Optional[MetricsFn] = None,
    server_optimizer_fn: OptimizerFn = functools.partial(
        tf.keras.optimizers.SGD, 1.0),
    client_optimizer_fn: OptimizerFn = functools.partial(
        tf.keras.optimizers.SGD, 0.1),
    reconstruction_optimizer_fn: OptimizerFn = functools.partial(
        tf.keras.optimizers.SGD, 0.1),
    dataset_split_fn: Optional[reconstruction_utils.DatasetSplitFn] = None,
    client_weighting: Optional[client_weight_lib.ClientWeightType] = None,
    broadcast_process: Optional[measured_process_lib.MeasuredProcess] = None,
    aggregation_factory: Optional[AggregationFactory] = None,
) -> iterative_process_lib.IterativeProcess:
    """Builds the IterativeProcess for optimization using FedRecon.

    Returns a `tff.templates.IterativeProcess` for Federated Reconstruction.
    Each client first (1) reconstructs its local variables with the global
    variables frozen, and then (2) trains the global variables.

    Args:
      model_fn: A no-arg function that returns a
        `tff.learning.reconstruction.Model`. This method must *not* capture
        Tensorflow tensors or variables and use them. Must be constructed
        entirely from scratch on each invocation; returning the same
        pre-constructed model each call will result in an error.
      loss_fn: A no-arg function returning a `tf.keras.losses.Loss` used to
        compute local updates during reconstruction and post-reconstruction,
        and reported as the `loss` metric. The final loss metric is the
        example-weighted mean loss across batches and across clients, and it
        excludes reconstruction batches.
      metrics_fn: A no-arg function returning a list of
        `tf.keras.metrics.Metric`s. Results are computed locally as described
        by each metric and aggregated across clients as in
        `federated_aggregate_keras_metric`. If None, only the loss is
        reported. Metrics are not computed on reconstruction batches.
      server_optimizer_fn: A `tff.learning.optimizers.Optimizer`, or a no-arg
        function that returns a `tf.keras.optimizers.Optimizer`, for applying
        updates to the global model on the server.
      client_optimizer_fn: A `tff.learning.optimizers.Optimizer`, or a no-arg
        function that returns a `tf.keras.optimizers.Optimizer`, for local
        client training after reconstruction.
      reconstruction_optimizer_fn: A `tff.learning.optimizers.Optimizer`, or a
        no-arg function that returns a `tf.keras.optimizers.Optimizer`, used in
        the reconstruction stage to fit local variables while the global ones
        are frozen.
      dataset_split_fn: A `reconstruction_utils.DatasetSplitFn` mapping one TF
        dataset to two: the first is iterated during reconstruction and the
        second during post-reconstruction training. Useful for multi-epoch
        iteration or for using disjoint data in the two stages. If None,
        `build_dataset_split_fn(split_dataset=True)` is used, dividing each
        client's batches between the two stages.
      client_weighting: A `tff.learning.ClientWeighting` value, or a callable
        mapping the model's local metrics to a weight for the federated
        average of model deltas. If None, clients are weighted by number of
        examples.
      broadcast_process: A `tff.templates.MeasuredProcess` that broadcasts the
        server model weights to the clients. Its `next` must have the signature
        `(<state@S, input@S> -> <state@S, result@C, measurements@S>)`. If None,
        a stateless broadcaster is built.
      aggregation_factory: An optional
        `tff.aggregators.WeightedAggregationFactory` or
        `tff.aggregators.UnweightedAggregationFactory` that determines how
        client updates are aggregated. If unspecified, a default
        `tff.aggregators.MeanFactory` computes a stateless mean across clients
        (weighted depending on `client_weighting`).

    Raises:
      TypeError: If `broadcast_process` does not have the expected signature.
      TypeError: If `aggregation_factory` does not have the expected signature.
      ValueError: If `aggregation_factory` is neither a
        `tff.aggregators.WeightedAggregationFactory` nor a
        `tff.aggregators.UnweightedAggregationFactory`.
      ValueError: If `aggregation_factory` is a
        `tff.aggregators.UnweightedAggregationFactory` but `client_weighting`
        is not `tff.learning.ClientWeighting.UNIFORM`.

    Returns:
      A `tff.templates.IterativeProcess` that additionally exposes a
      `get_model_weights` computation extracting the model from server state.
    """
    # Instantiate a model in a disposable graph only to read its metadata
    # (global-variable structure and input spec); it is never trained.
    with tf.Graph().as_default():
        metadata_model = model_fn()

    global_weights_type = type_conversions.type_from_tensors(
        reconstruction_utils.get_global_variables(metadata_model))

    client_weighting = (client_weight_lib.ClientWeighting.NUM_EXAMPLES
                        if client_weighting is None else client_weighting)
    # An unweighted aggregator cannot honor non-uniform client weights.
    uses_unweighted_factory = isinstance(
        aggregation_factory, factory.UnweightedAggregationFactory)
    weighting_is_uniform = (
        client_weighting is client_weight_lib.ClientWeighting.UNIFORM)
    if uses_unweighted_factory and not weighting_is_uniform:
        raise ValueError(
            'Expected `tff.learning.ClientWeighting.UNIFORM` client weighting '
            'with unweighted aggregator, instead got '
            f'{client_weighting}')

    broadcast_process = (
        optimizer_utils.build_stateless_broadcaster(
            model_weights_type=global_weights_type)
        if broadcast_process is None else broadcast_process)
    if not _is_valid_broadcast_process(broadcast_process):
        raise TypeError(
            'broadcast_process type signature does not conform to expected '
            'signature (<state@S, input@S> -> <state@S, result@C, measurements@S>).'
            f' Got: {broadcast_process.next.type_signature}')
    broadcaster_state_type = (
        broadcast_process.initialize.type_signature.result.member)

    aggregation_process = _instantiate_aggregation_process(
        aggregation_factory, global_weights_type)
    aggregator_state_type = (
        aggregation_process.initialize.type_signature.result.member)

    server_init_tff = _build_server_init_fn(
        model_fn,
        server_optimizer_fn,
        aggregation_process.initialize,
        broadcast_process.initialize,
    )
    server_state_type = server_init_tff.type_signature.result.member
    server_model_type = server_state_type.model

    server_update_fn = _build_server_update_fn(
        model_fn,
        server_optimizer_fn,
        server_state_type,
        server_model_type,
        aggregator_state_type=aggregator_state_type,
        broadcaster_state_type=broadcaster_state_type)

    client_dataset_type = computation_types.SequenceType(
        metadata_model.input_spec)
    dataset_split_fn = (
        reconstruction_utils.build_dataset_split_fn(split_dataset=True)
        if dataset_split_fn is None else dataset_split_fn)
    client_update_fn = _build_client_update_fn(
        model_fn,
        loss_fn=loss_fn,
        metrics_fn=metrics_fn,
        dataset_type=client_dataset_type,
        model_weights_type=server_model_type,
        client_optimizer_fn=client_optimizer_fn,
        reconstruction_optimizer_fn=reconstruction_optimizer_fn,
        dataset_split_fn=dataset_split_fn,
        client_weighting=client_weighting)

    federated_server_state_type = computation_types.at_server(
        server_state_type)
    federated_dataset_type = computation_types.at_clients(client_dataset_type)

    # Placeholder metrics whose only purpose is to derive the federated
    # aggregation of client-reported metrics; the loss metric goes last.
    placeholder_metrics = [
        *(metrics_fn() if metrics_fn is not None else ()),
        keras_utils.MeanLossMetric(loss_fn()),
    ]
    federated_output_computation = (
        keras_utils.federated_output_computation_from_metrics(
            placeholder_metrics))

    run_one_round_tff = _build_run_one_round_fn(
        server_update_fn,
        client_update_fn,
        federated_output_computation,
        federated_server_state_type,
        federated_dataset_type,
        aggregation_process=aggregation_process,
        broadcast_process=broadcast_process,
    )

    process = iterative_process_lib.IterativeProcess(
        initialize_fn=server_init_tff, next_fn=run_one_round_tff)

    @computations.tf_computation(server_state_type)
    def get_model_weights(server_state):
        return server_state.model

    process.get_model_weights = get_model_weights
    return process