def test_construction_fails_with_invalid_aggregation_factory(self):
   aggregation_factory = sampling.UnweightedReservoirSamplingFactory(
       sample_size=1)
   with self.assertRaisesRegex(
       TypeError, 'does not produce a compatible `AggregationProcess`'):
     optimizer_utils.build_model_delta_optimizer_process(
         model_fn=model_examples.LinearRegression,
         model_to_client_delta_fn=DummyClientDeltaFn,
         server_optimizer_fn=tf.keras.optimizers.SGD,
         model_update_aggregation_factory=aggregation_factory)
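For contrast, a minimal sketch of a construction that does produce a compatible `AggregationProcess`, using the weighted `mean.MeanFactory` that the `test_construction` example further down in this listing exercises (same assumed imports as the surrounding tests):

aggregation_factory = mean.MeanFactory()
optimizer_utils.build_model_delta_optimizer_process(
    model_fn=model_examples.LinearRegression,
    model_to_client_delta_fn=DummyClientDeltaFn,
    server_optimizer_fn=tf.keras.optimizers.SGD,
    model_update_aggregation_factory=aggregation_factory)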
Example #2
    def test_initial_weights_pulled_from_model(self, server_optimizer):

        self.skipTest('b/184855264')

        def _model_fn_with_zero_weights():
            linear_regression_model = model_examples.LinearRegression()
            weights = model_utils.ModelWeights.from_model(
                linear_regression_model)
            zero_trainable = [tf.zeros_like(x) for x in weights.trainable]
            zero_non_trainable = [
                tf.zeros_like(x) for x in weights.non_trainable
            ]
            zero_weights = model_utils.ModelWeights(
                trainable=zero_trainable, non_trainable=zero_non_trainable)
            zero_weights.assign_weights_to(linear_regression_model)
            return linear_regression_model

        def _model_fn_with_one_weights():
            linear_regression_model = model_examples.LinearRegression()
            weights = model_utils.ModelWeights.from_model(
                linear_regression_model)
            ones_trainable = [tf.ones_like(x) for x in weights.trainable]
            ones_non_trainable = [
                tf.ones_like(x) for x in weights.non_trainable
            ]
            ones_weights = model_utils.ModelWeights(
                trainable=ones_trainable, non_trainable=ones_non_trainable)
            ones_weights.assign_weights_to(linear_regression_model)
            return linear_regression_model

        iterative_process_returning_zeros = optimizer_utils.build_model_delta_optimizer_process(
            model_fn=_model_fn_with_zero_weights,
            model_to_client_delta_fn=DummyClientDeltaFn,
            server_optimizer_fn=server_optimizer())

        iterative_process_returning_ones = optimizer_utils.build_model_delta_optimizer_process(
            model_fn=_model_fn_with_one_weights,
            model_to_client_delta_fn=DummyClientDeltaFn,
            server_optimizer_fn=server_optimizer())

        zero_weights_expected = iterative_process_returning_zeros.initialize(
        ).model
        one_weights_expected = iterative_process_returning_ones.initialize(
        ).model

        self.assertEqual(
            sum(tf.reduce_sum(x) for x in zero_weights_expected.trainable) +
            sum(tf.reduce_sum(x) for x in zero_weights_expected.non_trainable),
            0)
        self.assertEqual(
            sum(tf.reduce_sum(x) for x in one_weights_expected.trainable) +
            sum(tf.reduce_sum(x) for x in one_weights_expected.non_trainable),
            type_analysis.count_tensors_in_type(
                iterative_process_returning_ones.initialize.type_signature.
                result.member.model)['parameters'])
Example #3
 def test_construction_calls_model_fn(self, server_optimizer):
     # Assert that the process building does not call `model_fn` too many
     # times. `model_fn` can potentially be expensive (loading weights,
     # processing, etc).
     mock_model_fn = mock.Mock(side_effect=model_examples.LinearRegression)
     optimizer_utils.build_model_delta_optimizer_process(
         model_fn=mock_model_fn,
         model_to_client_delta_fn=DummyClientDeltaFn,
         server_optimizer_fn=server_optimizer())
     # TODO(b/186451541): reduce the number of calls to model_fn.
     self.assertEqual(mock_model_fn.call_count, 3)
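The `mock.Mock(side_effect=...)` idiom above wraps the real model constructor so each call still builds a genuine model while `call_count` records how many times the process builder invoked it. A standalone sketch of the same counting pattern (the factory name here is purely illustrative):

from unittest import mock

def expensive_model_factory():
    # Stand-in for a model constructor that is costly to call.
    return object()

counted_factory = mock.Mock(side_effect=expensive_model_factory)
counted_factory()
counted_factory()
assert counted_factory.call_count == 2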
Example #4
 def test_fails_stateful_broadcast_and_process(self):
   model_weights_type = model_utils.weights_type_from_model(
       model_examples.LinearRegression)
   with self.assertRaises(optimizer_utils.DisjointArgumentError):
     optimizer_utils.build_model_delta_optimizer_process(
         model_fn=model_examples.LinearRegression,
         model_to_client_delta_fn=DummyClientDeltaFn,
         server_optimizer_fn=tf.keras.optimizers.SGD,
         stateful_model_broadcast_fn=computation_utils.StatefulBroadcastFn(
             initialize_fn=lambda: (),
             next_fn=lambda state, weights:  # pylint: disable=g-long-lambda
             (state, intrinsics.federated_broadcast(weights))),
         broadcast_process=optimizer_utils.build_stateless_broadcaster(
             model_weights_type=model_weights_type))
Example #5
 def test_fails_stateful_aggregate_and_process(self):
   model_weights_type = model_utils.weights_type_from_model(
       model_examples.LinearRegression)
   with self.assertRaises(optimizer_utils.DisjointArgumentError):
     optimizer_utils.build_model_delta_optimizer_process(
         model_fn=model_examples.LinearRegression,
         model_to_client_delta_fn=DummyClientDeltaFn,
         server_optimizer_fn=tf.keras.optimizers.SGD,
         stateful_delta_aggregate_fn=computation_utils.StatefulAggregateFn(
             initialize_fn=lambda: (),
             next_fn=lambda state, value, weight=None:  # pylint: disable=g-long-lambda
             (state, intrinsics.federated_mean(value, weight))),
         aggregation_process=optimizer_utils.build_stateless_mean(
             model_delta_type=model_weights_type.trainable))
Example #6
 def test_fails_stateful_broadcast_and_process(self):
     with tf.Graph().as_default():
         model_weights_type = tff.framework.type_from_tensors(
             model_utils.ModelWeights.from_model(
                 model_examples.LinearRegression()))
     with self.assertRaises(optimizer_utils.DisjointArgumentError):
         optimizer_utils.build_model_delta_optimizer_process(
             model_fn=model_examples.LinearRegression,
             model_to_client_delta_fn=DummyClientDeltaFn,
             server_optimizer_fn=tf.keras.optimizers.SGD,
             stateful_model_broadcast_fn=tff.utils.StatefulBroadcastFn(
                 initialize_fn=lambda: (),
                 next_fn=lambda state, weights:  # pylint: disable=g-long-lambda
                 (state, tff.federated_broadcast(weights))),
             broadcast_process=optimizer_utils.build_stateless_broadcaster(
                 model_weights_type=model_weights_type))
Example #7
    def test_iterative_process_with_encoding(self):
        model_fn = model_examples.LinearRegression
        gather_fn = encoding_utils.build_encoded_mean_from_model(
            model_fn, _test_encoder_fn('gather'))
        broadcast_fn = encoding_utils.build_encoded_broadcast_from_model(
            model_fn, _test_encoder_fn('simple'))
        iterative_process = optimizer_utils.build_model_delta_optimizer_process(
            model_fn=model_fn,
            model_to_client_delta_fn=DummyClientDeltaFn,
            server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=
                                                                1.0),
            stateful_delta_aggregate_fn=gather_fn,
            stateful_model_broadcast_fn=broadcast_fn)

        ds = tf.data.Dataset.from_tensor_slices(
            collections.OrderedDict([
                ('x', [[1.0, 2.0], [3.0, 4.0]]),
                ('y', [[5.0], [6.0]]),
            ])).batch(2)
        federated_ds = [ds] * 3

        state = iterative_process.initialize()
        self.assertEqual(state.model_broadcast_state.trainable[0][0], 1)

        state, _ = iterative_process.next(state, federated_ds)
        self.assertEqual(state.model_broadcast_state.trainable[0][0], 2)
    def test_construction_with_broadcast_process(self):
        model_weights_type = model_utils.weights_type_from_model(
            model_examples.LinearRegression)
        broadcast_process = _build_test_measured_broadcast(model_weights_type)
        iterative_process = optimizer_utils.build_model_delta_optimizer_process(
            model_fn=model_examples.LinearRegression,
            model_to_client_delta_fn=DummyClientDeltaFn,
            server_optimizer_fn=tf.keras.optimizers.SGD,
            broadcast_process=broadcast_process)

        expected_broadcast_state_type = broadcast_process.initialize.type_signature.result
        initialize_type = iterative_process.initialize.type_signature
        self.assertEqual(
            computation_types.FederatedType(
                initialize_type.result.member.model_broadcast_state,
                placements.SERVER), expected_broadcast_state_type)

        next_type = iterative_process.next.type_signature
        self.assertEqual(
            computation_types.FederatedType(
                next_type.parameter[0].member.model_broadcast_state,
                placements.SERVER), expected_broadcast_state_type)
        self.assertEqual(
            computation_types.FederatedType(
                next_type.result[0].member.model_broadcast_state,
                placements.SERVER), expected_broadcast_state_type)
    def test_orchestration_execute(self):
        iterative_process = optimizer_utils.build_model_delta_optimizer_process(
            model_fn=model_examples.TrainableLinearRegression,
            model_to_client_delta_fn=DummyClientDeltaFn,
            server_optimizer_fn=lambda: gradient_descent.SGD(learning_rate=1.0
                                                             ))

        ds = tf.data.Dataset.from_tensor_slices({
            'x': [[1., 2.], [3., 4.]],
            'y': [[5.], [6.]]
        }).batch(2)
        federated_ds = [ds] * 3

        state = iterative_process.initialize()
        self.assertSequenceAlmostEqual(state.model.trainable.a,
                                       np.zeros([2, 1], np.float32))
        self.assertAlmostEqual(state.model.trainable.b, 0.0)
        self.assertAlmostEqual(state.model.non_trainable.c, 0.0)

        state, outputs = iterative_process.next(state, federated_ds)
        self.assertSequenceAlmostEqual(state.model.trainable.a,
                                       -np.ones([2, 1], np.float32))
        self.assertAlmostEqual(state.model.trainable.b, -1.0)
        self.assertAlmostEqual(state.model.non_trainable.c, 0.0)

        # Since all predictions are 0, loss is:
        #    0.5 * ((0-5)^2 + (0-6)^2) / 2 = 15.25
        self.assertAlmostEqual(outputs.loss, 15.25, places=4)
        # 3 clients * 2 examples per client = 6 examples.
        self.assertAlmostEqual(outputs.num_examples, 6.0, places=8)
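The expected loss can be checked by hand: with all-zero initial weights every prediction is 0 and, as noted later in this listing, the example model's loss is 0.5 times the mean squared error over the batch. A small arithmetic sketch:

# Labels are 5.0 and 6.0, predictions are 0.0; loss = 0.5 * MSE.
expected_loss = 0.5 * ((0.0 - 5.0) ** 2 + (0.0 - 6.0) ** 2) / 2
assert expected_loss == 15.25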
Example #10
    def test_construction_with_aggregation_process(self, server_optimizer):
        model_update_type = model_utils.weights_type_from_model(
            model_examples.LinearRegression).trainable
        model_update_aggregator = TestMeasuredMeanFactory()
        iterative_process = optimizer_utils.build_model_delta_optimizer_process(
            model_fn=model_examples.LinearRegression,
            model_to_client_delta_fn=DummyClientDeltaFn,
            server_optimizer_fn=server_optimizer(),
            model_update_aggregation_factory=model_update_aggregator)

        agg_process = model_update_aggregator.create(
            model_update_type, computation_types.TensorType(tf.float32))

        aggregation_state_type = agg_process.initialize.type_signature.result
        initialize_type = iterative_process.initialize.type_signature
        self.assertEqual(
            computation_types.FederatedType(
                initialize_type.result.member.delta_aggregate_state,
                placements.SERVER), aggregation_state_type)

        next_type = iterative_process.next.type_signature
        self.assertEqual(
            computation_types.FederatedType(
                next_type.parameter[0].member.delta_aggregate_state,
                placements.SERVER), aggregation_state_type)
        self.assertEqual(
            computation_types.FederatedType(
                next_type.result[0].member.delta_aggregate_state,
                placements.SERVER), aggregation_state_type)

        agg_metrics_type = agg_process.next.type_signature.result.measurements
        self.assertEqual(
            computation_types.FederatedType(
                next_type.result[1].member.aggregation, placements.SERVER),
            agg_metrics_type)
def build_federated_averaging_process(
        model_fn,
        server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
        client_weight_fn=None):
    """Builds the TFF computations for optimization using federated averaging.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.TrainableModel`.
    server_optimizer_fn: A no-arg function that returns a `tf.Optimizer`. The
      `apply_gradients` method of this optimizer is used to apply client updates
      to the server model. The default creates a `tf.keras.optimizers.SGD` with
      a learning rate of 1.0, which simply adds the average client delta to the
      server's model.
    client_weight_fn: Optional function that takes the output of
      `model.report_local_outputs` and returns a tensor that provides the weight
      in the federated average of model deltas. If not provided, the default is
      the total number of examples processed on device.

  Returns:
    A `tff.utils.IterativeProcess`.
  """
    def client_fed_avg(model_fn):
        return ClientFedAvg(model_fn(), client_weight_fn)

    return optimizer_utils.build_model_delta_optimizer_process(
        model_fn, client_fed_avg, server_optimizer_fn)
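A minimal usage sketch for the builder above, assuming `model_examples.TrainableLinearRegression` satisfies the `tff.learning.TrainableModel` contract and that `federated_train_data` is an illustrative list of client `tf.data.Dataset`s like the ones built in the surrounding tests:

iterative_process = build_federated_averaging_process(
    model_fn=model_examples.TrainableLinearRegression,
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))
state = iterative_process.initialize()
# Each call to `next` runs one round: broadcast, local training,
# aggregation of client deltas, and a server optimizer step.
state, metrics = iterative_process.next(state, federated_train_data)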
    def test_construction_with_aggregation_process(self):
        model_update_type = model_utils.weights_type_from_model(
            model_examples.LinearRegression).trainable
        aggregation_process = _build_test_measured_mean(model_update_type)
        iterative_process = optimizer_utils.build_model_delta_optimizer_process(
            model_fn=model_examples.LinearRegression,
            model_to_client_delta_fn=DummyClientDeltaFn,
            server_optimizer_fn=tf.keras.optimizers.SGD,
            aggregation_process=aggregation_process)

        aggregation_state_type = aggregation_process.initialize.type_signature.result
        initialize_type = iterative_process.initialize.type_signature
        self.assertEqual(
            computation_types.FederatedType(
                initialize_type.result.member.delta_aggregate_state,
                placements.SERVER), aggregation_state_type)

        next_type = iterative_process.next.type_signature
        self.assertEqual(
            computation_types.FederatedType(
                next_type.parameter[0].member.delta_aggregate_state,
                placements.SERVER), aggregation_state_type)
        self.assertEqual(
            computation_types.FederatedType(
                next_type.result[0].member.delta_aggregate_state,
                placements.SERVER), aggregation_state_type)

        aggregation_metrics_type = aggregation_process.next.type_signature.result.measurements
        self.assertEqual(
            computation_types.FederatedType(
                next_type.result[1].member.aggregation, placements.SERVER),
            aggregation_metrics_type)
Example #13
    def test_orchestration_execute_measured_process(self, server_optimizer):
        model_weights_type = model_utils.weights_type_from_model(
            model_examples.LinearRegression)
        learning_rate = 1.0
        server_optimizer_fn = server_optimizer(learning_rate)
        iterative_process = optimizer_utils.build_model_delta_optimizer_process(
            model_fn=model_examples.LinearRegression,
            model_to_client_delta_fn=DummyClientDeltaFn,
            server_optimizer_fn=server_optimizer_fn,
            broadcast_process=_build_test_measured_broadcast(
                model_weights_type),
            model_update_aggregation_factory=TestMeasuredMeanFactory())

        client_dataset = tf.data.Dataset.from_tensor_slices(
            collections.OrderedDict([
                ('x', [[1.0, 2.0], [3.0, 4.0]]),
                ('y', [[5.0], [6.0]]),
            ])).batch(2)
        federated_dataset = [client_dataset] * 3

        state = iterative_process.initialize()
        if callable(server_optimizer_fn):
            # Keras SGD keeps track of a single scalar for the number of iterations.
            self.assertAllEqual(state.optimizer_state, [0])
        else:
            # TFF SGD stores learning rate in state.
            self.assertAllClose(
                state.optimizer_state,
                collections.OrderedDict([(optimizer.LEARNING_RATE_KEY,
                                          learning_rate)]))
        self.assertAllClose(list(state.model.trainable),
                            [np.zeros((2, 1)), 0.0])
        self.assertAllClose(list(state.model.non_trainable), [0.0])
        self.assertEqual(state.delta_aggregate_state, 0)
        self.assertEqual(state.model_broadcast_state, 0)

        state, outputs = iterative_process.next(state, federated_dataset)
        self.assertAllClose(
            # `DummyClientDeltaFn` always sends fake model weights deltas (negative
            # ones) back. Because the initial model weights are all zeros, the
            # updated model weights will be all negative ones.
            list(state.model.trainable),
            [-np.ones((2, 1)), -1.0])
        self.assertAllClose(list(state.model.non_trainable), [0.0])
        if callable(server_optimizer_fn):
            self.assertAllEqual(state.optimizer_state, [1])
        self.assertEqual(state.delta_aggregate_state, 1)
        self.assertEqual(state.model_broadcast_state, 1)

        expected_outputs = collections.OrderedDict(
            # `_build_test_measured_broadcast` builds a broadcast process whose
            # `measurements` is 3.0 + norm of initial model weights (which is zero).
            broadcast=3.0,
            aggregation=collections.OrderedDict(num_clients=3),
            train=collections.OrderedDict(
                # The average mean squared loss is computed at the initial model
                # weights (i.e., at zero weights): 0.5*(25+36)/2 = 15.25.
                loss=15.25,
                num_examples=6))
        self.assertAllEqual(expected_outputs, outputs)
Example #14
def build_federated_averaging_process(
    model_fn: Callable[[], model_lib.Model],
    client_optimizer_fn: Callable[[], tf.keras.optimizers.Optimizer],
    server_optimizer_fn: Callable[
        [], tf.keras.optimizers.Optimizer] = DEFAULT_SERVER_OPTIMIZER_FN,
    client_weight_fn: Callable[[Any], tf.Tensor] = None,
    stateful_delta_aggregate_fn=None,
    stateful_model_broadcast_fn=None) -> tff.utils.IterativeProcess:
  """Builds the TFF computations for optimization using federated averaging.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    client_optimizer_fn: A no-arg callable that returns a `tf.keras.Optimizer`.
    server_optimizer_fn: A no-arg callable that returns a `tf.keras.Optimizer`.
      The `apply_gradients` method of this optimizer is used to apply client
      updates to the server model. The default creates a
      `tf.keras.optimizers.SGD` with a learning rate of 1.0, which simply adds
      the average client delta to the server's model.
    client_weight_fn: Optional function that takes the output of
      `model.report_local_outputs` and returns a tensor that provides the weight
      in the federated average of model deltas. If not provided, the default is
      the total number of examples processed on device.
    stateful_delta_aggregate_fn: A `tff.utils.StatefulAggregateFn` where the
      `next_fn` performs a federated aggregation and updates state. That is, it
      has TFF type `(state@SERVER, value@CLIENTS, weights@CLIENTS) ->
      (state@SERVER, aggregate@SERVER)`, where the `value` type is
      `tff.learning.framework.ModelWeights.trainable` corresponding to the
      object returned by `model_fn`. By default performs arithmetic mean
      aggregation, weighted by `client_weight_fn`.
    stateful_model_broadcast_fn: A `tff.utils.StatefulBroadcastFn` where the
      `next_fn` performs a federated broadcast and updates state. That is, it has
      TFF type `(state@SERVER, value@SERVER) -> (state@SERVER, value@CLIENTS)`,
      where the `value` type is `tff.learning.framework.ModelWeights`
      corresponding to the object returned by `model_fn`. By default performs
      identity broadcast.

  Returns:
    A `tff.utils.IterativeProcess`.
  """

  def client_fed_avg(model_fn):
    return _ClientFedAvg(model_fn(), client_optimizer_fn(), client_weight_fn)

  if stateful_delta_aggregate_fn is None:
    stateful_delta_aggregate_fn = optimizer_utils.build_stateless_mean()
  else:
    py_typecheck.check_type(stateful_delta_aggregate_fn,
                            tff.utils.StatefulAggregateFn)

  if stateful_model_broadcast_fn is None:
    stateful_model_broadcast_fn = optimizer_utils.build_stateless_broadcaster()
  else:
    py_typecheck.check_type(stateful_model_broadcast_fn,
                            tff.utils.StatefulBroadcastFn)

  return optimizer_utils.build_model_delta_optimizer_process(
      model_fn, client_fed_avg, server_optimizer_fn,
      stateful_delta_aggregate_fn, stateful_model_broadcast_fn)
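A usage sketch for this variant, assuming the `model_examples.LinearRegression` model used throughout these tests and an illustrative `federated_train_data` list of client datasets; the client optimizer choice is arbitrary:

iterative_process = build_federated_averaging_process(
    model_fn=model_examples.LinearRegression,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.1),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))
state = iterative_process.initialize()
state, metrics = iterative_process.next(state, federated_train_data)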
    def test_construction(self, weighted):
        aggregation_factory = (mean.MeanFactory()
                               if weighted else sum_factory.SumFactory())
        iterative_process = optimizer_utils.build_model_delta_optimizer_process(
            model_fn=model_examples.LinearRegression,
            model_to_client_delta_fn=DummyClientDeltaFn,
            server_optimizer_fn=tf.keras.optimizers.SGD,
            model_update_aggregation_factory=aggregation_factory)

        if weighted:
            aggregate_state = collections.OrderedDict(value_sum_process=(),
                                                      weight_sum_process=())
            aggregate_metrics = collections.OrderedDict(mean_value=(),
                                                        mean_weight=())
        else:
            aggregate_state = ()
            aggregate_metrics = ()

        server_state_type = computation_types.FederatedType(
            optimizer_utils.ServerState(model=model_utils.ModelWeights(
                trainable=[
                    computation_types.TensorType(tf.float32, [2, 1]),
                    computation_types.TensorType(tf.float32)
                ],
                non_trainable=[computation_types.TensorType(tf.float32)]),
                                        optimizer_state=[tf.int64],
                                        delta_aggregate_state=aggregate_state,
                                        model_broadcast_state=()),
            placements.SERVER)
        self.assert_types_equivalent(
            computation_types.FunctionType(parameter=None,
                                           result=server_state_type),
            iterative_process.initialize.type_signature)

        dataset_type = computation_types.FederatedType(
            computation_types.SequenceType(
                collections.OrderedDict(
                    x=computation_types.TensorType(tf.float32, [None, 2]),
                    y=computation_types.TensorType(tf.float32, [None, 1]))),
            placements.CLIENTS)
        metrics_type = computation_types.FederatedType(
            collections.OrderedDict(
                broadcast=(),
                aggregation=aggregate_metrics,
                train=collections.OrderedDict(
                    loss=computation_types.TensorType(tf.float32),
                    num_examples=computation_types.TensorType(tf.int32)),
                stat=collections.OrderedDict(
                    num_examples=computation_types.TensorType(tf.float32))),
            placements.SERVER)
        self.assert_types_equivalent(
            computation_types.FunctionType(parameter=collections.OrderedDict(
                server_state=server_state_type,
                federated_dataset=dataset_type,
            ),
                                           result=(server_state_type,
                                                   metrics_type)),
            iterative_process.next.type_signature)
Example #16
 def test_iterative_process_with_encoding(self):
   model_fn = model_examples.LinearRegression
   gather_fn = encoding_utils.build_encoded_mean_from_model(
       model_fn, _test_encoder_fn('gather'))
   broadcast_fn = encoding_utils.build_encoded_broadcast_from_model(
       model_fn, _test_encoder_fn('simple'))
   iterative_process = optimizer_utils.build_model_delta_optimizer_process(
       model_fn=model_fn,
       model_to_client_delta_fn=DummyClientDeltaFn,
       server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
       stateful_delta_aggregate_fn=gather_fn,
       stateful_model_broadcast_fn=broadcast_fn)
   self._verify_iterative_process(iterative_process)
    def test_construction_with_adam_optimizer(self):
        iterative_process = optimizer_utils.build_model_delta_optimizer_process(
            model_fn=model_examples.LinearRegression,
            model_to_client_delta_fn=DummyClientDeltaFn,
            server_optimizer_fn=tf.keras.optimizers.Adam)
        # Assert that the optimizer_state includes the 5 variables (scalar for
        # # of iterations, plus two copies of the kernel and bias in the model).
        initialize_type = iterative_process.initialize.type_signature
        self.assertLen(initialize_type.result.member.optimizer_state, 5)

        next_type = iterative_process.next.type_signature
        self.assertLen(next_type.parameter[0].member.optimizer_state, 5)
        self.assertLen(next_type.result[0].member.optimizer_state, 5)
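The expected length follows from the slot variables Keras Adam keeps per model variable; a small arithmetic check (assuming the standard first- and second-moment accumulators):

num_model_variables = 2   # LinearRegression kernel and bias
slots_per_variable = 2    # Adam first- and second-moment accumulators
iteration_counter = 1     # single scalar tracking the number of iterations
assert iteration_counter + slots_per_variable * num_model_variables == 5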
    def test_orchestration_execute_measured_process(self, server_optimizer):
        model_weights_type = model_utils.weights_type_from_model(
            model_examples.LinearRegression)
        learning_rate = 1.0
        server_optimizer_fn = server_optimizer(learning_rate)
        iterative_process = optimizer_utils.build_model_delta_optimizer_process(
            model_fn=model_examples.LinearRegression,
            model_to_client_delta_fn=DummyClientDeltaFn,
            server_optimizer_fn=server_optimizer_fn,
            broadcast_process=_build_test_measured_broadcast(
                model_weights_type),
            model_update_aggregation_factory=TestMeasuredMeanFactory())

        ds = tf.data.Dataset.from_tensor_slices(
            collections.OrderedDict([
                ('x', [[1.0, 2.0], [3.0, 4.0]]),
                ('y', [[5.0], [6.0]]),
            ])).batch(2)
        federated_ds = [ds] * 3

        state = iterative_process.initialize()
        if callable(server_optimizer_fn):
            # Keras SGD keeps track of a single scalar for the number of iterations.
            self.assertAllEqual(state.optimizer_state, [0])
        else:
            # TFF SGD has an empty state.
            self.assertEmpty(state.optimizer_state)
        self.assertAllClose(list(state.model.trainable),
                            [np.zeros((2, 1)), 0.0])
        self.assertAllClose(list(state.model.non_trainable), [0.0])
        self.assertEqual(state.delta_aggregate_state, 0)
        self.assertEqual(state.model_broadcast_state, 0)

        state, outputs = iterative_process.next(state, federated_ds)
        self.assertAllClose(list(state.model.trainable),
                            [-np.ones((2, 1)), -1.0 * learning_rate])
        self.assertAllClose(list(state.model.non_trainable), [0.0])
        if callable(server_optimizer_fn):
            self.assertAllEqual(state.optimizer_state, [1])
        self.assertEqual(state.delta_aggregate_state, 1)
        self.assertEqual(state.model_broadcast_state, 1)

        expected_outputs = collections.OrderedDict(
            broadcast=3.0,
            aggregation=collections.OrderedDict(num_clients=3),
            train={
                'loss': 15.25,
                'num_examples': 6,
            },
            stat=collections.OrderedDict(num_examples=3.0))
        self.assertAllEqual(expected_outputs, outputs)
    def test_orchestration_type_signature(self):
        iterative_process = optimizer_utils.build_model_delta_optimizer_process(
            model_fn=model_examples.TrainableLinearRegression,
            model_to_client_delta_fn=DummyClientDeltaFn,
            server_optimizer_fn=lambda: gradient_descent.SGD(learning_rate=1.0
                                                             ))

        expected_model_weights_type = model_utils.ModelWeights(
            collections.OrderedDict([('a', tff.TensorType(tf.float32, [2, 1])),
                                     ('b', tf.float32)]),
            collections.OrderedDict([('c', tf.float32)]))

        # ServerState consists of a model and optimizer_state. The optimizer_state
        # is provided by TensorFlow; TFF doesn't care what the actual value is.
        expected_federated_server_state_type = tff.FederatedType(
            optimizer_utils.ServerState(expected_model_weights_type,
                                        test.AnyType(), test.AnyType(),
                                        test.AnyType()),
            placement=tff.SERVER,
            all_equal=True)

        expected_federated_dataset_type = tff.FederatedType(tff.SequenceType(
            model_examples.TrainableLinearRegression().input_spec),
                                                            tff.CLIENTS,
                                                            all_equal=False)

        expected_model_output_types = tff.FederatedType(
            collections.OrderedDict([
                ('loss', tff.TensorType(tf.float32, [])),
                ('num_examples', tff.TensorType(tf.int32, [])),
            ]),
            tff.SERVER,
            all_equal=True)

        # `initialize` is expected to be a function of no arguments to a ServerState.
        self.assertEqual(
            tff.FunctionType(parameter=None,
                             result=expected_federated_server_state_type),
            iterative_process.initialize.type_signature)

        # `next` is expected to be a function of (ServerState, Datasets) to
        # ServerState.
        self.assertEqual(
            tff.FunctionType(parameter=[
                expected_federated_server_state_type,
                expected_federated_dataset_type
            ],
                             result=(expected_federated_server_state_type,
                                     expected_model_output_types)),
            iterative_process.next.type_signature)
    def test_orchestration_execute_measured_process(self):
        with tf.Graph().as_default():
            model_weights_type = tff.framework.type_from_tensors(
                model_utils.ModelWeights.from_model(
                    model_examples.LinearRegression()))
        learning_rate = 1.0
        iterative_process = optimizer_utils.build_model_delta_optimizer_process(
            model_fn=model_examples.LinearRegression,
            model_to_client_delta_fn=DummyClientDeltaFn,
            server_optimizer_fn=functools.partial(tf.keras.optimizers.SGD,
                                                  learning_rate=learning_rate),
            broadcast_process=_build_test_measured_broadcast(
                model_weights_type),
            aggregation_process=_build_test_measured_mean(
                model_weights_type.trainable))

        ds = tf.data.Dataset.from_tensor_slices(
            collections.OrderedDict([
                ('x', [[1.0, 2.0], [3.0, 4.0]]),
                ('y', [[5.0], [6.0]]),
            ])).batch(2)
        federated_ds = [ds] * 3

        state = iterative_process.initialize()
        # SGD keeps track of a single scalar for the number of iterations.
        self.assertAllEqual(state.optimizer_state, [0])
        self.assertAllClose(list(state.model.trainable),
                            [np.zeros((2, 1)), 0.0])
        self.assertAllClose(list(state.model.non_trainable), [0.0])
        self.assertEqual(state.delta_aggregate_state, 0)
        self.assertEqual(state.model_broadcast_state, 0)

        state, outputs = iterative_process.next(state, federated_ds)
        self.assertAllClose(list(state.model.trainable),
                            [-np.ones((2, 1)), -1.0 * learning_rate])
        self.assertAllClose(list(state.model.non_trainable), [0.0])
        self.assertAllEqual(state.optimizer_state, [1])
        self.assertEqual(state.delta_aggregate_state, 1)
        self.assertEqual(state.model_broadcast_state, 1)

        expected_outputs = anonymous_tuple.from_container(
            collections.OrderedDict(
                broadcast=3.0,
                aggregation=collections.OrderedDict(num_clients=3),
                train=collections.OrderedDict(
                    loss=15.25,
                    num_examples=6,
                )),
            recursive=True)
        self.assertEqual(expected_outputs, outputs)
    def test_orchestration_execute_stateful_fn(self):
        learning_rate = 1.0
        iterative_process = optimizer_utils.build_model_delta_optimizer_process(
            model_fn=model_examples.LinearRegression,
            model_to_client_delta_fn=DummyClientDeltaFn,
            server_optimizer_fn=functools.partial(tf.keras.optimizers.SGD,
                                                  learning_rate=learning_rate),
            # A federated_mean that maintains an int32 state equal to the
            # number of times the federated_mean has been executed,
            # allowing us to test that a stateful aggregator's state
            # is properly updated.
            stateful_delta_aggregate_fn=state_incrementing_mean,
            # Similarly, a broadcast with state that increments:
            stateful_model_broadcast_fn=state_incrementing_broadcaster)

        ds = tf.data.Dataset.from_tensor_slices(
            collections.OrderedDict([
                ('x', [[1.0, 2.0], [3.0, 4.0]]),
                ('y', [[5.0], [6.0]]),
            ])).batch(2)
        federated_ds = [ds] * 3

        state = iterative_process.initialize()
        # SGD keeps track of a single scalar for the number of iterations.
        self.assertAllEqual(state.optimizer_state, [0])
        self.assertAllClose(list(state.model.trainable),
                            [np.zeros((2, 1)), 0.0])
        self.assertAllClose(list(state.model.non_trainable), [0.0])
        self.assertEqual(state.delta_aggregate_state, 0)
        self.assertEqual(state.model_broadcast_state, 0)

        state, outputs = iterative_process.next(state, federated_ds)
        self.assertAllClose(list(state.model.trainable),
                            [-np.ones((2, 1)), -1.0 * learning_rate])
        self.assertAllClose(list(state.model.non_trainable), [0.0])
        self.assertAllEqual(state.optimizer_state, [1])
        self.assertEqual(state.delta_aggregate_state, 1)
        self.assertEqual(state.model_broadcast_state, 1)

        expected_outputs = anonymous_tuple.from_container(
            collections.OrderedDict(
                # The StatefulFns output no metrics.
                broadcast=(),
                aggregation=(),
                train=collections.OrderedDict(
                    loss=15.25,
                    num_examples=6,
                )),
            recursive=True)
        self.assertEqual(expected_outputs, outputs)
  def test_construction(self):
    iterative_process = optimizer_utils.build_model_delta_optimizer_process(
        model_fn=model_examples.LinearRegression,
        model_to_client_delta_fn=DummyClientDeltaFn,
        server_optimizer_fn=tf.keras.optimizers.SGD)

    server_state_type = computation_types.FederatedType(
        optimizer_utils.ServerState(
            model=model_utils.ModelWeights(
                trainable=[
                    computation_types.TensorType(tf.float32, [2, 1]),
                    computation_types.TensorType(tf.float32)
                ],
                non_trainable=[computation_types.TensorType(tf.float32)]),
            optimizer_state=[tf.int64],
            delta_aggregate_state=(),
            model_broadcast_state=()), placements.SERVER)

    self.assertEqual(
        str(iterative_process.initialize.type_signature),
        str(
            computation_types.FunctionType(
                parameter=None, result=server_state_type)))

    dataset_type = computation_types.FederatedType(
        computation_types.SequenceType(
            collections.OrderedDict(
                x=computation_types.TensorType(tf.float32, [None, 2]),
                y=computation_types.TensorType(tf.float32, [None, 1]))),
        placements.CLIENTS)

    metrics_type = computation_types.FederatedType(
        collections.OrderedDict(
            broadcast=(),
            aggregation=(),
            train=collections.OrderedDict(
                loss=computation_types.TensorType(tf.float32),
                num_examples=computation_types.TensorType(tf.int32))),
        placements.SERVER)

    self.assertEqual(
        str(iterative_process.next.type_signature),
        str(
            computation_types.FunctionType(
                parameter=collections.OrderedDict(
                    server_state=server_state_type,
                    federated_dataset=dataset_type,
                ),
                result=(server_state_type, metrics_type))))
Example #23
    def test_construction_with_tff_sgdm_optimizer(self, momentum, state_len):
        iterative_process = optimizer_utils.build_model_delta_optimizer_process(
            model_fn=model_examples.LinearRegression,
            model_to_client_delta_fn=DummyClientDeltaFn,
            server_optimizer_fn=sgdm.build_sgdm(learning_rate=0.1,
                                                momentum=momentum))
        # Assert that the optimizer_state is empty for SGD without momentum;
        # and includes two tensors for the momentum: kernel and bias from the model.
        initialize_type = iterative_process.initialize.type_signature
        self.assertLen(initialize_type.result.member.optimizer_state,
                       state_len)

        next_type = iterative_process.next.type_signature
        self.assertLen(next_type.parameter[0].member.optimizer_state,
                       state_len)
        self.assertLen(next_type.result[0].member.optimizer_state, state_len)
    def test_construction_with_broadcast_state(self):
        iterative_process = optimizer_utils.build_model_delta_optimizer_process(
            model_fn=model_examples.LinearRegression,
            model_to_client_delta_fn=DummyClientDeltaFn,
            server_optimizer_fn=tf.keras.optimizers.SGD,
            stateful_model_broadcast_fn=state_incrementing_broadcaster)

        expected_broadcast_state_type = tff.TensorType(tf.int32)

        initialize_type = iterative_process.initialize.type_signature
        self.assertEqual(initialize_type.result.member.model_broadcast_state,
                         expected_broadcast_state_type)

        next_type = iterative_process.next.type_signature
        self.assertEqual(next_type.parameter[0].member.model_broadcast_state,
                         expected_broadcast_state_type)
        self.assertEqual(next_type.result[0].member.model_broadcast_state,
                         expected_broadcast_state_type)
    def test_orchestration_execute_sgd(self):
        learning_rate = 1.0
        iterative_process = optimizer_utils.build_model_delta_optimizer_process(
            model_fn=model_examples.LinearRegression,
            model_to_client_delta_fn=DummyClientDeltaFn,
            server_optimizer_fn=functools.partial(tf.keras.optimizers.SGD,
                                                  learning_rate=learning_rate),
            # A federated_mean that maintains an int32 state equal to the
            # number of times the federated_mean has been executed,
            # allowing us to test that a stateful aggregator's state
            # is properly updated.
            stateful_delta_aggregate_fn=state_incrementing_mean,
            # Similarly, a broadcast with state that increments:
            stateful_model_broadcast_fn=state_incrementing_broadcaster)

        ds = tf.data.Dataset.from_tensor_slices(
            collections.OrderedDict([
                ('x', [[1.0, 2.0], [3.0, 4.0]]),
                ('y', [[5.0], [6.0]]),
            ])).batch(2)
        federated_ds = [ds] * 3

        state = iterative_process.initialize()
        # SGD keeps track of a single scalar for the number of iterations.
        self.assertAllEqual(state.optimizer_state, [0])
        self.assertAllClose(list(state.model.trainable),
                            [np.zeros((2, 1)), 0.0])
        self.assertAllClose(list(state.model.non_trainable), [0.0])
        self.assertEqual(state.delta_aggregate_state, 0)
        self.assertEqual(state.model_broadcast_state, 0)

        state, outputs = iterative_process.next(state, federated_ds)
        self.assertAllClose(list(state.model.trainable),
                            [-np.ones((2, 1)), -1.0 * learning_rate])
        self.assertAllClose(list(state.model.non_trainable), [0.0])
        self.assertAllEqual(state.optimizer_state, [1])
        self.assertEqual(state.delta_aggregate_state, 1)
        self.assertEqual(state.model_broadcast_state, 1)

        # Since all predictions are 0, loss is:
        #    0.5 * ((0-5)^2 + (0-6)^2) / 2 = 15.25
        self.assertAlmostEqual(outputs.loss, 15.25, places=4)
        # 3 clients * 2 examples per client = 6 examples.
        self.assertAlmostEqual(outputs.num_examples, 6.0, places=8)
    def test_orchestration_execute(self):
        iterative_process = optimizer_utils.build_model_delta_optimizer_process(
            model_fn=model_examples.TrainableLinearRegression,
            model_to_client_delta_fn=DummyClientDeltaFn,
            server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=
                                                                1.0),
            # A federated_mean that maintains an int32 state equal to the
            # number of times the federated_mean has been executed,
            # allowing us to test that a stateful aggregator's state
            # is properly updated.
            stateful_delta_aggregate_fn=state_incrementing_mean,
            # Similarly, a broadcast with state that increments:
            stateful_model_broadcast_fn=state_incrementing_broadcaster)

        ds = tf.data.Dataset.from_tensor_slices(
            collections.OrderedDict([
                ('x', [[1.0, 2.0], [3.0, 4.0]]),
                ('y', [[5.0], [6.0]]),
            ])).batch(2)
        federated_ds = [ds] * 3

        state = iterative_process.initialize()
        self.assertSequenceAlmostEqual(state.model.trainable.a,
                                       np.zeros([2, 1], np.float32))
        self.assertAlmostEqual(state.model.trainable.b, 0.0)
        self.assertAlmostEqual(state.model.non_trainable.c, 0.0)
        self.assertEqual(state.delta_aggregate_state, 0)
        self.assertEqual(state.model_broadcast_state, 0)

        state, outputs = iterative_process.next(state, federated_ds)
        self.assertSequenceAlmostEqual(state.model.trainable.a,
                                       -np.ones([2, 1], np.float32))
        self.assertAlmostEqual(state.model.trainable.b, -1.0)
        self.assertAlmostEqual(state.model.non_trainable.c, 0.0)
        self.assertEqual(state.delta_aggregate_state, 1)
        self.assertEqual(state.model_broadcast_state, 1)

        # Since all predictions are 0, loss is:
        #    0.5 * ((0-5)^2 + (0-6)^2) / 2 = 15.25
        self.assertAlmostEqual(outputs.loss, 15.25, places=4)
        # 3 clients * 2 examples per client = 6 examples.
        self.assertAlmostEqual(outputs.num_examples, 6.0, places=8)
Example #27
  def test_iterative_process_with_encoding(self):
    model_fn = model_examples.TrainableLinearRegression
    broadcast_fn = encoding_utils.build_encoded_broadcast_from_model(
        model_fn, _test_encoder_fn())
    iterative_process = optimizer_utils.build_model_delta_optimizer_process(
        model_fn=model_fn,
        model_to_client_delta_fn=DummyClientDeltaFn,
        server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
        stateful_model_broadcast_fn=broadcast_fn)

    ds = tf.data.Dataset.from_tensor_slices({
        'x': [[1., 2.], [3., 4.]],
        'y': [[5.], [6.]]
    }).batch(2)
    federated_ds = [ds] * 3

    state = iterative_process.initialize()
    self.assertEqual(state.model_broadcast_state.trainable.a[0], 1)

    state, _ = iterative_process.next(state, federated_ds)
    self.assertEqual(state.model_broadcast_state.trainable.a[0], 2)
def build_federated_process_for_test(model_fn, num_passes=5, tolerance=1e-6):
    """Build a test FedAvg process with a dummy client computation.

  Analogue of `build_federated_averaging_process`, but with client_fed_avg
  replaced by the dummy mean computation defined above.

  Args:
    model_fn: callable that returns a `tff.learning.Model`.
    num_passes: integer number of communication rounds in the smoothed
      Weiszfeld algorithm (min. 1).
    tolerance: float smoothing parameter of smoothed Weiszfeld algorithm.
      Default 1e-6.

  Returns:
    A `tff.utils.IterativeProcess`.
  """

    server_optimizer_fn = lambda: tf.keras.optimizers.SGD(learning_rate=1.0)

    def client_fed_avg(model_fn):
        return DummyClientComputation(model_fn(), client_weight_fn=None)

    # Build robust aggregation function
    with tf.Graph().as_default():
        # workaround since keras automatically appends "_n" to the nth call of
        # `model_fn`
        model_type = tff.framework.type_from_tensors(
            model_fn().weights.trainable)

        stateful_delta_aggregate_fn = rfa.build_stateless_robust_aggregation(
            model_type,
            num_communication_passes=num_passes,
            tolerance=tolerance)

        stateful_model_broadcast_fn = optimizer_utils.build_stateless_broadcaster(
        )

        return optimizer_utils.build_model_delta_optimizer_process(
            model_fn, client_fed_avg, server_optimizer_fn,
            stateful_delta_aggregate_fn, stateful_model_broadcast_fn)
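A usage sketch for the test helper above, assuming a `model_fn` such as `model_examples.LinearRegression` that exposes the `model.weights` attribute the helper relies on, and an illustrative `federated_train_data` list of client datasets:

iterative_process = build_federated_process_for_test(
    model_examples.LinearRegression, num_passes=5, tolerance=1e-6)
state = iterative_process.initialize()
state, metrics = iterative_process.next(state, federated_train_data)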
Example #29
def build_federated_sgd_process(
    model_fn: Callable[[], model_lib.Model],
    server_optimizer_fn: Callable[
        [], tf.keras.optimizers.Optimizer] = DEFAULT_SERVER_OPTIMIZER_FN,
    client_weight_fn: Callable[[Any], tf.Tensor] = None,
    *,
    broadcast_process: Optional[measured_process.MeasuredProcess] = None,
    aggregation_process: Optional[measured_process.MeasuredProcess] = None,
    model_update_aggregation_factory: Optional[
        factory.AggregationProcessFactory] = None,
    use_experimental_simulation_loop: bool = False,
) -> iterative_process.IterativeProcess:
    """Builds the TFF computations for optimization using federated SGD.

  This function creates a `tff.templates.IterativeProcess` that performs
  federated SGD on client models. The iterative process has the following
  methods:

  *   `initialize`: A `tff.Computation` with the functional type signature
      `( -> S@SERVER)`, where `S` is a `tff.learning.framework.ServerState`
      representing the initial state of the server.
  *   `next`: A `tff.Computation` with the functional type signature
      `(<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>)` where `S` is a
      `tff.learning.framework.ServerState` whose type matches that of the output
      of `initialize`, and `{B*}@CLIENTS` represents the client datasets, where
      `B` is the type of a single batch. This computation returns a
      `tff.learning.framework.ServerState` representing the updated server state
      and metrics that are the result of
      `tff.learning.Model.federated_output_computation` during client training
      and any other metrics from broadcast and aggregation processes.

  Each time the `next` method is called, the server model is broadcast to each
  client using a broadcast function. Each client sums the gradients at each
  batch in the client's local dataset. These gradient sums are then aggregated
  at the server using an aggregation function. The aggregate gradients are
  applied at the server by using the
  `tf.keras.optimizers.Optimizer.apply_gradients` method of the server
  optimizer.

  This implements the original FedSGD algorithm in [McMahan et al.,
  2017](https://arxiv.org/abs/1602.05629).

  Note: the default server optimizer function is `tf.keras.optimizers.SGD`
  with a learning rate of 0.1. More sophisticated federated SGD procedures may
  use different learning rates or server optimizers.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`. This method
      must *not* capture TensorFlow tensors or variables and use them. The model
      must be constructed entirely from scratch on each invocation; returning
      the same pre-constructed model on each call will result in an error.
    server_optimizer_fn: A no-arg function that returns a `tf.Optimizer`. The
      `apply_gradients` method of this optimizer is used to apply client updates
      to the server model.
    client_weight_fn: Optional function that takes the output of
      `model.report_local_outputs` and returns a tensor that provides the weight
      in the federated average of the aggregated gradients. If not provided, the
      default is the total number of examples processed on device.
    broadcast_process: a `tff.templates.MeasuredProcess` that broadcasts the
      model weights on the server to the clients. It must support the signature
      `(input_values@SERVER -> output_values@CLIENT)`.
    aggregation_process: a `tff.templates.MeasuredProcess` that aggregates the
      model updates on the clients back to the server. It must support the
      signature `({input_values}@CLIENTS -> output_values@SERVER)`. Must be
      `None` if `model_update_aggregation_factory` is not `None`.
    model_update_aggregation_factory: An optional
      `tff.aggregators.AggregationProcessFactory` that constructs
      `tff.templates.AggregationProcess` for aggregating the client model
      updates on the server. If `None`, uses a default constructed
      `tff.aggregators.MeanFactory`, creating a stateless mean aggregation. Must
      be `None` if `aggregation_process` is not `None`.
    use_experimental_simulation_loop: Controls the reduce loop function for
        input dataset. An experimental reduce loop is used for simulation.

  Returns:
    A `tff.templates.IterativeProcess`.
  """
    def client_sgd_avg(model_fn: Callable[[], model_lib.Model]) -> ClientSgd:
        return ClientSgd(
            model_fn(),
            client_weight_fn,
            use_experimental_simulation_loop=use_experimental_simulation_loop)

    return optimizer_utils.build_model_delta_optimizer_process(
        model_fn,
        model_to_client_delta_fn=client_sgd_avg,
        server_optimizer_fn=server_optimizer_fn,
        broadcast_process=broadcast_process,
        aggregation_process=aggregation_process,
        model_update_aggregation_factory=model_update_aggregation_factory)
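A minimal usage sketch for `build_federated_sgd_process`, assuming the same `model_examples.LinearRegression` model and an illustrative `federated_train_data` list of client datasets; broadcast and aggregation arguments are left at their defaults:

iterative_process = build_federated_sgd_process(
    model_fn=model_examples.LinearRegression)
state = iterative_process.initialize()
# One round: broadcast weights, sum client gradients, apply them on the server.
state, metrics = iterative_process.next(state, federated_train_data)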
Example #30
    def test_execute_measured_process_with_custom_metrics_aggregator(
            self, server_optimizer):
        model_weights_type = model_utils.weights_type_from_model(
            model_examples.LinearRegression)
        learning_rate = 1.0
        server_optimizer_fn = server_optimizer(learning_rate)

        def custom_metrics_aggregator(metric_finalizers,
                                      local_unfinalized_metrics_type):
            """Builds a TFF computation that computes per-client min/max metrics."""
            @tensorflow_computation.tf_computation(
                local_unfinalized_metrics_type)
            def finalizer_computation(unfinalized_metrics):
                finalized_metrics = collections.OrderedDict()
                for metric_name, finalizer in metric_finalizers.items():
                    finalized_metrics[metric_name] = finalizer(
                        unfinalized_metrics[metric_name])
                return finalized_metrics

            @federated_computation.federated_computation(
                computation_types.at_clients(local_unfinalized_metrics_type))
            def aggregator_computation(client_local_unfinalized_metrics):
                client_local_finalized_metrics = intrinsics.federated_map(
                    finalizer_computation, client_local_unfinalized_metrics)
                aggregated_metrics = collections.OrderedDict(
                    loss_per_client_max=primitives.federated_max(
                        client_local_finalized_metrics['loss']),
                    loss_per_client_min=primitives.federated_min(
                        client_local_finalized_metrics['loss']),
                    num_examples_per_client_max=primitives.federated_max(
                        client_local_finalized_metrics['num_examples']),
                    num_examples_per_client_min=primitives.federated_min(
                        client_local_finalized_metrics['num_examples']))
                return intrinsics.federated_zip(aggregated_metrics)

            return aggregator_computation

        iterative_process = optimizer_utils.build_model_delta_optimizer_process(
            model_fn=model_examples.LinearRegression,
            model_to_client_delta_fn=DummyClientDeltaFn,
            server_optimizer_fn=server_optimizer_fn,
            broadcast_process=_build_test_measured_broadcast(
                model_weights_type),
            model_update_aggregation_factory=TestMeasuredMeanFactory(),
            metrics_aggregator=custom_metrics_aggregator)

        # The first client has 1 example. The second client has 2 examples.
        client_1_dataset = tf.data.Dataset.from_tensor_slices(
            collections.OrderedDict([
                ('x', [[-1.0, -1.0]]),
                ('y', [[1.0]]),
            ])).batch(1)
        client_2_dataset = tf.data.Dataset.from_tensor_slices(
            collections.OrderedDict([
                ('x', [[1.0, 2.0], [3.0, 4.0]]),
                ('y', [[5.0], [6.0]]),
            ])).batch(2)
        federated_dataset = [client_1_dataset, client_2_dataset]

        state = iterative_process.initialize()
        if callable(server_optimizer_fn):
            # Keras SGD keeps track of a single scalar for the number of iterations.
            self.assertAllEqual(state.optimizer_state, [0])
        else:
            # TFF SGD stores learning rate in state.
            self.assertAllClose(
                state.optimizer_state,
                collections.OrderedDict([(optimizer.LEARNING_RATE_KEY,
                                          learning_rate)]))
        self.assertAllClose(list(state.model.trainable),
                            [np.zeros((2, 1)), 0.0])
        self.assertAllClose(list(state.model.non_trainable), [0.0])
        self.assertEqual(state.delta_aggregate_state, 0)
        self.assertEqual(state.model_broadcast_state, 0)

        state, outputs = iterative_process.next(state, federated_dataset)
        self.assertAllClose(
            # `DummyClientDeltaFn` always sends fake model weights deltas (negative
            # ones) back. Because the initial model weights are all zeros, the
            # updated model weights will be all negative ones.
            list(state.model.trainable),
            [-np.ones((2, 1)), -1.0])
        self.assertAllClose(list(state.model.non_trainable), [0.0])
        if callable(server_optimizer_fn):
            self.assertAllEqual(state.optimizer_state, [1])
        self.assertEqual(state.delta_aggregate_state, 1)
        self.assertEqual(state.model_broadcast_state, 1)

        expected_outputs = collections.OrderedDict(
            # `_build_test_measured_broadcast` builds a broadcast process whose
            # `measurements` is 3.0 + norm of initial model weights (which is zero).
            broadcast=3.0,
            aggregation=collections.OrderedDict(num_clients=2),
            train=collections.OrderedDict(
                # The average mean squared loss is computed at the initial model
                # weights (i.e., at zero weights): the loss is 0.5 for the first
                # client and is 0.5*(25+36)/2 = 15.25 for the second client.
                loss_per_client_max=15.25,
                loss_per_client_min=0.5,
                num_examples_per_client_max=2,
                num_examples_per_client_min=1))
        self.assertAllClose(expected_outputs, outputs)