def test_execution_with_custom_dp_query(self):
    client_data = create_emnist_client_data()
    train_data = [client_data(), client_data()]

    def loss_fn():
      return tf.keras.losses.SparseCategoricalCrossentropy()

    def metrics_fn():
      return [
          NumExamplesCounter(),
          NumBatchesCounter(),
          tf.keras.metrics.SparseCategoricalAccuracy()
      ]

    # No values should be changed, but using inf directly zeroes out all
    # updates. Prefer a very large clipping norm that can still be handled in
    # multiplication/division.
    gaussian_sum_query = tfp.GaussianSumQuery(l2_norm_clip=1e10, stddev=0)
    dp_sum_factory = tff.aggregators.DifferentiallyPrivateFactory(
        query=gaussian_sum_query,
        record_aggregation_factory=tff.aggregators.SumFactory())
    dp_mean_factory = _DPMean(dp_sum_factory)

    # Disable reconstruction via 0 learning rate to ensure post-recon loss
    # matches exact expectations in round 0 and decreases by the next round.
    trainer = training_process.build_federated_reconstruction_process(
        MnistModel,
        loss_fn=loss_fn,
        metrics_fn=metrics_fn,
        server_optimizer_fn=functools.partial(tf.keras.optimizers.SGD, 0.01),
        client_optimizer_fn=functools.partial(tf.keras.optimizers.SGD, 0.001),
        reconstruction_optimizer_fn=functools.partial(tf.keras.optimizers.SGD,
                                                      0.0),
        aggregation_factory=dp_mean_factory,
    )
    state = trainer.initialize()

    outputs = []
    states = []
    for _ in range(2):
      state, output = trainer.next(state, train_data)
      outputs.append(output)
      states.append(state)

    # All weights and biases are initialized to 0, so initial logits are all 0
    # and softmax probabilities are uniform over 10 classes. So negative log
    # likelihood is -ln(1/10). This is on expectation, so increase tolerance.
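    # Numerically, -ln(1/10) = ln(10) ≈ 2.3026.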
    self.assertAllClose(outputs[0]['loss'], tf.math.log(10.0), rtol=1e-4)
    self.assertLess(outputs[1]['loss'], outputs[0]['loss'])
    self.assertNotAllClose(states[0].model.trainable, states[1].model.trainable)

    # Expect 6 reconstruction examples, 6 training examples. Only training
    # included in metrics.
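    # (Each test client from `create_emnist_client_data` holds 3 examples, so 2
    # clients contribute 6 training examples.)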
    self.assertEqual(outputs[0]['num_examples_total'], 6.0)
    self.assertEqual(outputs[1]['num_examples_total'], 6.0)

    # Expect 4 reconstruction batches and 4 training batches. Only training
    # included in metrics.
    self.assertEqual(outputs[0]['num_batches_total'], 4.0)
    self.assertEqual(outputs[1]['num_batches_total'], 4.0)
  def test_keras_local_layer_metrics_none(self):

    def loss_fn():
      return tf.keras.losses.SparseCategoricalCrossentropy()

    it_process = training_process.build_federated_reconstruction_process(
        local_recon_model_fn,
        loss_fn=loss_fn,
        metrics_fn=None,
        client_optimizer_fn=functools.partial(tf.keras.optimizers.SGD, 0.001),
        reconstruction_optimizer_fn=functools.partial(tf.keras.optimizers.SGD,
                                                      0.001))

    server_state = it_process.initialize()

    client_data = create_emnist_client_data()
    federated_data = [client_data(), client_data()]

    server_states = []
    outputs = []
    for _ in range(2):
      server_state, output = it_process.next(server_state, federated_data)
      server_states.append(server_state)
      outputs.append(output)

    expected_keys = ['loss']
    self.assertCountEqual(outputs[0].keys(), expected_keys)
    self.assertLess(outputs[1]['loss'], outputs[0]['loss'])
    self.assertNotAllClose(server_states[0].model.trainable,
                           server_states[1].model.trainable)
    def test_custom_model_zeroing_clipping_aggregator_factory(self):
        client_data = create_emnist_client_data()
        train_data = [client_data(), client_data()]

        def loss_fn():
            return tf.keras.losses.SparseCategoricalCrossentropy()

        def metrics_fn():
            return [
                NumExamplesCounter(),
                NumBatchesCounter(),
                tf.keras.metrics.SparseCategoricalAccuracy()
            ]

        # No values should be clipped or zeroed.
        aggregation_factory = tff.aggregators.zeroing_factory(
            zeroing_norm=float('inf'),
            inner_agg_factory=tff.aggregators.MeanFactory())

        # Disable reconstruction via 0 learning rate to ensure post-recon loss
        # matches exact expectations in round 0 and decreases by the next round.
        trainer = training_process.build_federated_reconstruction_process(
            MnistModel,
            loss_fn=loss_fn,
            metrics_fn=metrics_fn,
            server_optimizer_fn=functools.partial(tf.keras.optimizers.SGD,
                                                  0.01),
            client_optimizer_fn=functools.partial(tf.keras.optimizers.SGD,
                                                  0.001),
            reconstruction_optimizer_fn=functools.partial(
                tf.keras.optimizers.SGD, 0.0),
            aggregation_factory=aggregation_factory,
        )
        state = trainer.initialize()

        outputs = []
        states = []
        for _ in range(2):
            state, output = trainer.next(state, train_data)
            outputs.append(output)
            states.append(state)

        # All weights and biases are initialized to 0, so initial logits are all 0
        # and softmax probabilities are uniform over 10 classes. So negative log
        # likelihood is -ln(1/10). This is on expectation, so increase tolerance.
        self.assertAllClose(outputs[0]['loss'], tf.math.log(10.0), rtol=1e-4)
        self.assertLess(outputs[1]['loss'], outputs[0]['loss'])
        self.assertNotAllClose(states[0].model.trainable,
                               states[1].model.trainable)

        # Expect 6 reconstruction examples, 6 training examples. Only training
        # included in metrics.
        self.assertEqual(outputs[0]['num_examples_total'], 6.0)
        self.assertEqual(outputs[1]['num_examples_total'], 6.0)

        # Expect 4 reconstruction batches and 4 training batches. Only training
        # included in metrics.
        self.assertEqual(outputs[0]['num_batches_total'], 4.0)
        self.assertEqual(outputs[1]['num_batches_total'], 4.0)
def iterative_process_builder(
    model_fn: Callable[[], reconstruction_model.ReconstructionModel],
    loss_fn: Callable[[], List[tf.keras.losses.Loss]],
    metrics_fn: Optional[Callable[[], List[tf.keras.metrics.Metric]]] = None,
    client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
    dataset_split_fn_builder: Callable[
        ..., reconstruction_utils.DatasetSplitFn] = reconstruction_utils
    .build_dataset_split_fn,
    task_name: str = 'stackoverflow_nwp',
    dp_noise_multiplier: Optional[float] = None,
    dp_zeroing: bool = True,
    clients_per_round: int = 5,
) -> tff.templates.IterativeProcess:
  """Creates an iterative process using a given TFF `model_fn`."""

  # Get aggregation factory for DP, if needed.
  aggregation_factory = None
  client_weighting = client_weight_fn
  if dp_noise_multiplier is not None:
    aggregation_factory = tff.learning.dp_aggregator(
        noise_multiplier=dp_noise_multiplier,
        clients_per_round=float(clients_per_round),
        zeroing=dp_zeroing)
    # DP is only implemented for uniform weighting.
    client_weighting = lambda _: 1.0

  if task_name == 'stackoverflow_nwp_finetune':
    # The returned iterative process is essentially the same as the one
    # created by the standard `tff.learning.build_federated_averaging_process`.
    client_epochs_per_round = 1

    # No need to split the client data as the model has only global variables.
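    # (`repeat(0)` below yields an empty reconstruction dataset, so every client
    # step is ordinary training on the global variables.)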
    def dataset_split_fn(
        client_dataset: tf.data.Dataset,
        round_num: tf.Tensor) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
      del round_num
      return client_dataset.repeat(0), client_dataset.repeat(
          client_epochs_per_round)

  else:
    dataset_split_fn = dataset_split_fn_builder(
        recon_epochs_max=1,
        recon_epochs_constant=1,
        recon_steps_max=1,
        post_recon_epochs=1,
        post_recon_steps_max=1,
        split_dataset=False)

  return training_process.build_federated_reconstruction_process(
      model_fn=model_fn,
      loss_fn=loss_fn,
      metrics_fn=metrics_fn,
      server_optimizer_fn=lambda: tf.keras.optimizers.SGD(1.0),
      client_optimizer_fn=lambda: tf.keras.optimizers.SGD(1.0),
      reconstruction_optimizer_fn=lambda: tf.keras.optimizers.SGD(1.0),
      client_weight_fn=client_weighting,
      dataset_split_fn=dataset_split_fn,
      aggregation_factory=aggregation_factory)
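
# A minimal usage sketch of `iterative_process_builder` (the `my_model_fn` and
# `my_loss_fn` constructors are hypothetical and shown for illustration only):
#
#   process = iterative_process_builder(
#       model_fn=my_model_fn,
#       loss_fn=my_loss_fn,
#       dp_noise_multiplier=0.5,
#       clients_per_round=10)
#   state = process.initialize()
#   state, metrics = process.next(state, federated_train_data)
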
    def test_custom_model_eval_reconstruction_disable_post_recon(self):
        """Ensures we can disable post-recon on a client via custom `DatasetSplitFn`."""
        client_data = create_emnist_client_data()
        train_data = [client_data(batch_size=3), client_data(batch_size=2)]

        def loss_fn():
            return tf.keras.losses.SparseCategoricalCrossentropy()

        def metrics_fn():
            return [
                NumExamplesCounter(),
                NumBatchesCounter(),
                tf.keras.metrics.SparseCategoricalAccuracy()
            ]

        def dataset_split_fn(client_dataset, round_num):
            del round_num
            recon_dataset = client_dataset.repeat(2)
            # One user gets 1 batch with 1 example, the other user gets 0 batches.
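            # (Each synthetic client holds 3 examples: with batch_size=3 that is a
            # single batch, so `skip(1)` leaves nothing; batch_size=2 yields two
            # batches, so `skip(1)` keeps only the final 1-example batch.)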
            post_recon_dataset = client_dataset.skip(1)
            return recon_dataset, post_recon_dataset

        trainer = training_process.build_federated_reconstruction_process(
            MnistModel,
            loss_fn=loss_fn,
            metrics_fn=metrics_fn,
            client_optimizer_fn=functools.partial(tf.keras.optimizers.SGD,
                                                  0.001),
            evaluate_reconstruction=True,
            jointly_train_variables=True,
            dataset_split_fn=dataset_split_fn)
        state = trainer.initialize()

        outputs = []
        states = []
        for _ in range(2):
            state, output = trainer.next(state, train_data)
            outputs.append(output)
            states.append(state)

        # One client should still have a delta that updates the global weights, so
        # there should be a change in the server state and loss should still
        # decrease.
        self.assertLess(outputs[1]['loss'], outputs[0]['loss'])
        self.assertNotAllClose(states[0].model.trainable,
                               states[1].model.trainable)

        # Expect 12 reconstruction examples, 1 training example.
        self.assertEqual(outputs[0]['num_examples_total'], 13.0)
        self.assertEqual(outputs[1]['num_examples_total'], 13.0)

        # Expect 6 reconstruction batches and 1 training batch.
        self.assertEqual(outputs[0]['num_batches_total'], 7.0)
        self.assertEqual(outputs[1]['num_batches_total'], 7.0)
    def test_custom_model_eval_reconstruction_multiple_epochs(self):
        client_data = create_emnist_client_data()
        train_data = [client_data(), client_data()]

        def loss_fn():
            return tf.keras.losses.SparseCategoricalCrossentropy()

        def metrics_fn():
            return [
                NumExamplesCounter(),
                NumBatchesCounter(),
                tf.keras.metrics.SparseCategoricalAccuracy()
            ]

        dataset_split_fn = reconstruction_utils.build_dataset_split_fn(
            recon_epochs_max=3,
            recon_epochs_constant=False,
            post_recon_epochs=4,
            post_recon_steps_max=3)
        trainer = training_process.build_federated_reconstruction_process(
            MnistModel,
            loss_fn=loss_fn,
            metrics_fn=metrics_fn,
            client_optimizer_fn=functools.partial(tf.keras.optimizers.SGD,
                                                  0.001),
            reconstruction_optimizer_fn=functools.partial(
                tf.keras.optimizers.SGD, 0.001),
            evaluate_reconstruction=True,
            dataset_split_fn=dataset_split_fn)
        state = trainer.initialize()

        outputs = []
        states = []
        for _ in range(2):
            state, output = trainer.next(state, train_data)
            outputs.append(output)
            states.append(state)

        self.assertLess(outputs[1]['loss'], outputs[0]['loss'])
        self.assertNotAllClose(states[0].model.trainable,
                               states[1].model.trainable)

        # Expect 6 reconstruction examples, 10 training examples.
        self.assertEqual(outputs[0]['num_examples_total'], 16.0)
        # Expect 12 reconstruction examples, 10 training examples.
        self.assertEqual(outputs[1]['num_examples_total'], 22.0)

        # Expect 4 reconstruction batches and 6 training batches.
        self.assertEqual(outputs[0]['num_batches_total'], 10.0)
        # Expect 8 reconstruction batches and 6 training batches.
        self.assertEqual(outputs[1]['num_batches_total'], 14.0)
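        # Consistent with these counts, `recon_epochs_constant=False` appears to
        # run 1 recon epoch in round 1 and 2 in round 2 (capped at 3), while
        # `post_recon_steps_max=3` limits each client to 3 training batches
        # (2 + 1 + 2 = 5 examples) out of the 4 post-recon epochs.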
    def test_keras_eval_reconstruction_joint_training(self):
        def loss_fn():
            return tf.keras.losses.SparseCategoricalCrossentropy()

        def metrics_fn():
            return [
                NumExamplesCounter(),
                NumBatchesCounter(),
                tf.keras.metrics.SparseCategoricalAccuracy()
            ]

        it_process = training_process.build_federated_reconstruction_process(
            local_recon_model_fn,
            loss_fn=loss_fn,
            metrics_fn=metrics_fn,
            client_optimizer_fn=functools.partial(tf.keras.optimizers.SGD,
                                                  0.001),
            reconstruction_optimizer_fn=functools.partial(
                tf.keras.optimizers.SGD, 0.001),
            evaluate_reconstruction=True,
            jointly_train_variables=True)

        server_state = it_process.initialize()

        client_data = create_emnist_client_data()
        federated_data = [client_data(), client_data()]

        server_states = []
        outputs = []
        loss_list = []
        for _ in range(5):
            server_state, output = it_process.next(server_state,
                                                   federated_data)
            server_states.append(server_state)
            outputs.append(output)
            loss_list.append(output['loss'])

        expected_keys = [
            'sparse_categorical_accuracy', 'loss', 'num_examples_total',
            'num_batches_total'
        ]
        self.assertCountEqual(outputs[0].keys(), expected_keys)
        self.assertLess(np.mean(loss_list[2:]), np.mean(loss_list[:2]))
        self.assertNotAllClose(server_states[0].model.trainable,
                               server_states[1].model.trainable)
        self.assertEqual(outputs[0]['num_examples_total'], 12)
        self.assertEqual(outputs[1]['num_examples_total'], 12)
        self.assertEqual(outputs[0]['num_batches_total'], 8)
        self.assertEqual(outputs[1]['num_batches_total'], 8)
    def test_custom_model_eval_reconstruction_split_multiple_epochs(self):
        client_data = create_emnist_client_data()
        # 3 batches per user, each with one example. Since data will be split for
        # each user, each user will have 2 unique recon examples and 1 unique
        # post-recon example (even-indexed examples go to recon during splitting).
        train_data = [client_data(batch_size=1), client_data(batch_size=1)]

        def loss_fn():
            return tf.keras.losses.SparseCategoricalCrossentropy()

        def metrics_fn():
            return [
                NumExamplesCounter(),
                NumBatchesCounter(),
                tf.keras.metrics.SparseCategoricalAccuracy()
            ]

        dataset_split_fn = reconstruction_utils.build_dataset_split_fn(
            recon_epochs_max=3, split_dataset=True, post_recon_epochs=5)
        trainer = training_process.build_federated_reconstruction_process(
            MnistModel,
            loss_fn=loss_fn,
            metrics_fn=metrics_fn,
            client_optimizer_fn=functools.partial(tf.keras.optimizers.SGD,
                                                  0.001),
            evaluate_reconstruction=True,
            dataset_split_fn=dataset_split_fn)
        state = trainer.initialize()

        outputs = []
        states = []
        for _ in range(2):
            state, output = trainer.next(state, train_data)
            outputs.append(output)
            states.append(state)

        self.assertLess(outputs[1]['loss'], outputs[0]['loss'])
        self.assertNotAllClose(states[0].model.trainable,
                               states[1].model.trainable)

        # Expect 12 reconstruction examples, 10 training examples.
        self.assertEqual(outputs[0]['num_examples_total'], 22.0)
        self.assertEqual(outputs[1]['num_examples_total'], 22.0)

        # Expect 12 reconstruction batches and 10 training batches.
        self.assertEqual(outputs[0]['num_batches_total'], 22.0)
        self.assertEqual(outputs[1]['num_batches_total'], 22.0)
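        # Derivation: with `split_dataset=True`, each client's 3 examples split
        # into 2 recon examples (even indices) and 1 post-recon example, so
        # 2 * 3 recon epochs + 1 * 5 post-recon epochs = 11 examples per client
        # (22 total); with batch_size=1, batch counts equal example counts.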
  def test_fed_recon_with_custom_client_weight_fn(self):
    client_data = create_emnist_client_data()
    federated_data = [client_data()]

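    # Weight each client's model delta inversely to (1 + its final local loss).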
    def client_weight_fn(local_outputs):
      return 1.0 / (1.0 + local_outputs['loss'][-1])

    it_process = training_process.build_federated_reconstruction_process(
        local_recon_model_fn,
        loss_fn=tf.keras.losses.SparseCategoricalCrossentropy,
        client_optimizer_fn=functools.partial(tf.keras.optimizers.SGD, 0.001),
        reconstruction_optimizer_fn=functools.partial(tf.keras.optimizers.SGD,
                                                      0.001),
        client_weight_fn=client_weight_fn)

    _, train_outputs, _ = self._run_rounds(it_process, federated_data, 5)
    self.assertLess(train_outputs[-1]['loss'], train_outputs[0]['loss'])
  def test_personal_matrix_factorization_trains_reconstruction_model(self):
    train_data = [
        self.train_users.flatten().tolist(),
        self.train_items.flatten().tolist(),
        self.train_preferences.flatten().tolist()
    ]
    train_tf_dataset = tf.data.Dataset.from_tensor_slices(
        list(zip(*train_data)))

    def batch_map_fn(example_batch):
      return collections.OrderedDict(
          x=tf.cast(example_batch[:, 0:1], tf.int64), y=example_batch[:, 1:2])

    train_tf_dataset = train_tf_dataset.batch(1).map(batch_map_fn).repeat(5)
    train_tf_datasets = [train_tf_dataset] * 2

    num_users = 1
    num_items = 8
    num_latent_factors = 10
    personal_model = True
    add_biases = False
    l2_regularization = 0.0

    tff_model_fn = models.build_reconstruction_model(
        functools.partial(
            models.get_matrix_factorization_model,
            num_users,
            num_items,
            num_latent_factors,
            personal_model=personal_model,
            add_biases=add_biases,
            l2_regularization=l2_regularization))
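    # With `personal_model=True`, the user embedding is presumably kept as a
    # locally reconstructed variable rather than being aggregated on the server.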

    # Also test `models.get_loss_fn` and `models.get_metrics_fn`.
    trainer = training_process.build_federated_reconstruction_process(
        tff_model_fn,
        loss_fn=models.get_loss_fn(),
        metrics_fn=models.get_metrics_fn(),
        client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1e-2),
        reconstruction_optimizer_fn=(
            lambda: tf.keras.optimizers.SGD(learning_rate=1e-3)),
        dataset_split_fn=reconstruction_utils.build_dataset_split_fn(
            recon_epochs_max=10))

    state = trainer.initialize()
    trainer.next(state, train_tf_datasets)
  def test_server_update_with_inf_weight_is_noop(self):
    client_data = create_emnist_client_data()
    federated_data = [client_data()]
    client_weight_fn = lambda x: np.inf

    it_process = training_process.build_federated_reconstruction_process(
        local_recon_model_fn,
        loss_fn=tf.keras.losses.SparseCategoricalCrossentropy,
        client_optimizer_fn=functools.partial(tf.keras.optimizers.SGD, 0.001),
        reconstruction_optimizer_fn=functools.partial(tf.keras.optimizers.SGD,
                                                      0.001),
        client_weight_fn=client_weight_fn)

    state, _, initial_state = self._run_rounds(it_process, federated_data, 1)
    self.assertAllClose(state.model.trainable, initial_state.model.trainable,
                        1e-8)
    def test_iterative_process_builds_with_dp_agg_and_client_weight_fn(self):
        def loss_fn():
            return tf.keras.losses.SparseCategoricalCrossentropy()

        def metrics_fn():
            return [
                NumExamplesCounter(),
                NumBatchesCounter(),
                tf.keras.metrics.SparseCategoricalAccuracy()
            ]

        # No values should be changed, but using inf directly zeroes out all
        # updates. Prefer a very large clipping norm that can still be handled in
        # multiplication/division.
        gaussian_sum_query = tfp.GaussianSumQuery(l2_norm_clip=1e10, stddev=0)
        dp_sum_factory = tff.aggregators.DifferentiallyPrivateFactory(
            query=gaussian_sum_query,
            record_aggregation_factory=tff.aggregators.SumFactory())
        dp_mean_factory = _DPMean(dp_sum_factory)

        def client_weight_fn(local_outputs):
            del local_outputs  # Unused
            return 1.0

        # Ensure this builds, as some builders raise if an unweighted aggregation is
        # specified with a client_weight_fn.
        trainer = training_process.build_federated_reconstruction_process(
            MnistModel,
            loss_fn=loss_fn,
            metrics_fn=metrics_fn,
            server_optimizer_fn=functools.partial(tf.keras.optimizers.SGD,
                                                  0.01),
            client_optimizer_fn=functools.partial(tf.keras.optimizers.SGD,
                                                  0.001),
            reconstruction_optimizer_fn=functools.partial(
                tf.keras.optimizers.SGD, 0.0),
            aggregation_factory=dp_mean_factory,
            client_weight_fn=client_weight_fn,
        )
        self.assertIsInstance(trainer, tff.templates.IterativeProcess)
  def test_build_train_iterative_process(self):

    def loss_fn():
      return tf.keras.losses.SparseCategoricalCrossentropy()

    def metrics_fn():
      return [
          NumExamplesCounter(),
          NumBatchesCounter(),
          tf.keras.metrics.SparseCategoricalAccuracy()
      ]

    it_process = training_process.build_federated_reconstruction_process(
        local_recon_model_fn,
        loss_fn=loss_fn,
        metrics_fn=metrics_fn,
        client_optimizer_fn=functools.partial(tf.keras.optimizers.SGD, 0.1))

    self.assertIsInstance(it_process, tff.templates.IterativeProcess)
    federated_data_type = it_process.next.type_signature.parameter[1]
    self.assertEqual(
        str(federated_data_type), '{<x=float32[?,784],y=int32[?,1]>*}@CLIENTS')
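    # That is, each client holds a sequence of batches of flattened 28x28 pixel
    # images (784 features) with integer labels.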
  def test_get_model_weights(self):
    client_data = create_emnist_client_data()
    federated_data = [client_data()]

    it_process = training_process.build_federated_reconstruction_process(
        local_recon_model_fn,
        loss_fn=tf.keras.losses.SparseCategoricalCrossentropy,
        client_optimizer_fn=functools.partial(tf.keras.optimizers.SGD, 0.001),
        reconstruction_optimizer_fn=functools.partial(tf.keras.optimizers.SGD,
                                                      0.001))
    state = it_process.initialize()

    self.assertIsInstance(
        it_process.get_model_weights(state), tff.learning.ModelWeights)
    self.assertAllClose(state.model.trainable,
                        it_process.get_model_weights(state).trainable)

    for _ in range(3):
      state, _ = it_process.next(state, federated_data)
      self.assertIsInstance(
          it_process.get_model_weights(state), tff.learning.ModelWeights)
      self.assertAllClose(state.model.trainable,
                          it_process.get_model_weights(state).trainable)
    def iterative_process_builder(
        model_fn: Callable[[], reconstruction_model.ReconstructionModel],
        loss_fn: Callable[[], List[tf.keras.losses.Loss]],
        metrics_fn: Optional[Callable[[],
                                      List[tf.keras.metrics.Metric]]] = None,
        client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
        dataset_split_fn_builder: Callable[
            ..., reconstruction_utils.DatasetSplitFn] = reconstruction_utils.
        build_dataset_split_fn,
    ) -> tff.templates.IterativeProcess:
        """Creates an iterative process using a given TFF `model_fn`.

    For a `stackoverflow_nwp_finetune` task, the `model_fn` must return a model
    that has only global variables, and the argument `dataset_split_fn_builder`
    is ignored. The returned iterative process is basically the same as the one
    created by the standard `tff.learning.build_federated_averaging_process`.

    For other tasks, the returned iterative process performs the federated
    reconstruction algorithm defined by
    `training_process.build_federated_reconstruction_process`.

    Args:
      model_fn: A no-arg function returning a
        `reconstruction_model.ReconstructionModel`. The returned model must have
        only global variables for a `stackoverflow_nwp_finetune` task.
      loss_fn: A no-arg function returning a list of `tf.keras.losses.Loss`.
      metrics_fn: A no-arg function returning a list of
        `tf.keras.metrics.Metric`.
      client_weight_fn: Optional function that takes the local model's output,
        and returns a tensor that provides the weight in the federated average
        of model deltas. If not provided, the default is the total number of
        examples processed on device. If DP is used, this argument is ignored,
        and uniform client weighting is used.
      dataset_split_fn_builder: `DatasetSplitFn` builder. Returns a method used
        to split the examples into a reconstruction, and post-reconstruction
        set. Ignored for a `stackoverflow_nwp_finetune` task.

    Raises:
      ValueError: if `model_fn` returns a model with local variables for a
        `stackoverflow_nwp_finetune` task.

    Returns:
      A `tff.templates.IterativeProcess`.
    """

        # Get aggregation factory for DP, if needed.
        aggregation_factory = None
        client_weighting = client_weight_fn
        if FLAGS.dp_noise_multiplier is not None:
            aggregation_factory = tff.learning.dp_aggregator(
                noise_multiplier=FLAGS.dp_noise_multiplier,
                clients_per_round=float(FLAGS.clients_per_round),
                zeroing=FLAGS.dp_zeroing)
            # DP is only implemented for uniform weighting.
            client_weighting = lambda _: 1.0

        if FLAGS.task == 'stackoverflow_nwp_finetune':

            if not reconstruction_utils.has_only_global_variables(model_fn()):
                raise ValueError(
                    '`model_fn` should return a model with only global variables. '
                )

            def fake_dataset_split_fn(
                client_dataset: tf.data.Dataset, round_num: tf.Tensor
            ) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
                del round_num
                return client_dataset.repeat(0), client_dataset.repeat(
                    FLAGS.client_epochs_per_round)

            return training_process.build_federated_reconstruction_process(
                model_fn=model_fn,
                loss_fn=loss_fn,
                metrics_fn=metrics_fn,
                server_optimizer_fn=lambda: server_optimizer_fn(
                    FLAGS.server_learning_rate),
                client_optimizer_fn=lambda: client_optimizer_fn(
                    FLAGS.client_learning_rate),
                dataset_split_fn=fake_dataset_split_fn,
                client_weight_fn=client_weighting,
                aggregation_factory=aggregation_factory)

        return training_process.build_federated_reconstruction_process(
            model_fn=model_fn,
            loss_fn=loss_fn,
            metrics_fn=metrics_fn,
            server_optimizer_fn=lambda: server_optimizer_fn(
                FLAGS.server_learning_rate),
            client_optimizer_fn=lambda: client_optimizer_fn(
                FLAGS.client_learning_rate),
            reconstruction_optimizer_fn=functools.partial(
                reconstruction_optimizer_fn,
                FLAGS.reconstruction_learning_rate),
            dataset_split_fn=dataset_split_fn_builder(
                recon_epochs_max=FLAGS.recon_epochs_max,
                recon_epochs_constant=FLAGS.recon_epochs_constant,
                recon_steps_max=FLAGS.recon_steps_max,
                post_recon_epochs=FLAGS.post_recon_epochs,
                post_recon_steps_max=FLAGS.post_recon_steps_max,
                split_dataset=FLAGS.split_dataset),
            evaluate_reconstruction=FLAGS.evaluate_reconstruction,
            jointly_train_variables=FLAGS.jointly_train_variables,
            client_weight_fn=client_weighting,
            aggregation_factory=aggregation_factory)
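
    # The FLAGS referenced above (dp_noise_multiplier, clients_per_round, task,
    # recon/post-recon settings, learning rates, etc.) are assumed to be abseil
    # flags defined elsewhere in the launching binary, e.g. (illustrative only):
    #
    #   from absl import flags
    #   FLAGS = flags.FLAGS
    #   flags.DEFINE_float('dp_noise_multiplier', None,
    #                      'Noise multiplier for DP-FedAvg; None disables DP.')
    #   flags.DEFINE_integer('clients_per_round', 10, 'Clients per round.')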