def test_construction_fails_with_invalid_aggregation_factory(self):
    """Checks that an incompatible aggregation factory is rejected.

    Building the process with an `UnweightedReservoirSamplingFactory` as the
    model-update aggregation factory is expected to raise a `TypeError`.
    """
    incompatible_factory = sampling.UnweightedReservoirSamplingFactory(
        sample_size=1)
    expected_message = 'does not produce a compatible `AggregationProcess`'
    with self.assertRaisesRegex(TypeError, expected_message):
        optimizer_utils.build_model_delta_optimizer_process(
            model_fn=model_examples.LinearRegression,
            model_to_client_delta_fn=DummyClientDeltaFn,
            server_optimizer_fn=tf.keras.optimizers.SGD,
            model_update_aggregation_factory=incompatible_factory)
def test_initial_weights_pulled_from_model(self, server_optimizer):
    """Checks that the initial server weights come from the model `model_fn` builds.

    Two processes are built from model functions whose weights are overwritten
    with all zeros and all ones respectively. The summed initial weights are
    then compared against 0 and against the parameter count of the model type.

    Args:
      server_optimizer: A no-arg callable whose result is passed as
        `server_optimizer_fn`.
    """
    # Skipped unconditionally (see b/184855264); the code below does not run
    # until this call is removed.
    self.skipTest('b/184855264')

    def _constant_weights_model_fn(fill_like):
        """Returns a `model_fn` whose model weights are `fill_like(weight)`."""

        def model_fn():
            # The model must be instantiated: weights are read from and
            # assigned to a model object, and `model_fn` has to return that
            # object rather than the `LinearRegression` class itself.
            linear_regression_model = model_examples.LinearRegression()
            weights = model_utils.ModelWeights.from_model(
                linear_regression_model)
            filled_weights = model_utils.ModelWeights(
                trainable=[fill_like(x) for x in weights.trainable],
                non_trainable=[fill_like(x) for x in weights.non_trainable])
            filled_weights.assign_weights_to(linear_regression_model)
            return linear_regression_model

        return model_fn

    iterative_process_returning_zeros = (
        optimizer_utils.build_model_delta_optimizer_process(
            model_fn=_constant_weights_model_fn(tf.zeros_like),
            model_to_client_delta_fn=DummyClientDeltaFn,
            server_optimizer_fn=server_optimizer()))
    iterative_process_returning_ones = (
        optimizer_utils.build_model_delta_optimizer_process(
            model_fn=_constant_weights_model_fn(tf.ones_like),
            model_to_client_delta_fn=DummyClientDeltaFn,
            server_optimizer_fn=server_optimizer()))

    zero_weights_expected = iterative_process_returning_zeros.initialize().model
    one_weights_expected = iterative_process_returning_ones.initialize().model

    # All-zero weights sum to zero.
    self.assertEqual(
        sum(tf.reduce_sum(x) for x in zero_weights_expected.trainable) +
        sum(tf.reduce_sum(x) for x in zero_weights_expected.non_trainable),
        0)
    # All-one weights sum to the 'parameters' count reported for the model
    # type in the `initialize` signature.
    self.assertEqual(
        sum(tf.reduce_sum(x) for x in one_weights_expected.trainable) +
        sum(tf.reduce_sum(x) for x in one_weights_expected.non_trainable),
        type_analysis.count_tensors_in_type(
            iterative_process_returning_ones.initialize.type_signature.result
            .member.model)['parameters'])
def test_construction_calls_model_fn(self, server_optimzier):
    """Pins how many times process construction invokes `model_fn`.

    Args:
      server_optimzier: A no-arg callable whose result is passed as
        `server_optimizer_fn`.
    """
    # `model_fn` can be expensive (loading weights, processing, etc.), so
    # building the process should not call it too many times.
    counting_model_fn = mock.Mock(side_effect=model_examples.LinearRegression)
    server_optimizer_fn = server_optimzier()
    optimizer_utils.build_model_delta_optimizer_process(
        model_fn=counting_model_fn,
        model_to_client_delta_fn=DummyClientDeltaFn,
        server_optimizer_fn=server_optimizer_fn)
    # TODO(b/186451541): reduce the number of calls to model_fn.
    self.assertEqual(counting_model_fn.call_count, 3)
def test_fails_stateful_broadcast_and_process(self):
    """Checks that a stateful broadcast fn and a broadcast process conflict.

    Passing both `stateful_model_broadcast_fn` and `broadcast_process` is
    expected to raise `DisjointArgumentError`.
    """
    weights_type = model_utils.weights_type_from_model(
        model_examples.LinearRegression)

    def _broadcast_next(state, weights):
        return state, intrinsics.federated_broadcast(weights)

    # Both arguments are built inside the `assertRaises` scope, in the same
    # order in which the original call evaluated them.
    with self.assertRaises(optimizer_utils.DisjointArgumentError):
        stateful_broadcast = computation_utils.StatefulBroadcastFn(
            initialize_fn=lambda: (), next_fn=_broadcast_next)
        stateless_process = optimizer_utils.build_stateless_broadcaster(
            model_weights_type=weights_type)
        optimizer_utils.build_model_delta_optimizer_process(
            model_fn=model_examples.LinearRegression,
            model_to_client_delta_fn=DummyClientDeltaFn,
            server_optimizer_fn=tf.keras.optimizers.SGD,
            stateful_model_broadcast_fn=stateful_broadcast,
            broadcast_process=stateless_process)
def test_fails_stateful_aggregate_and_process(self):
    """Checks that a stateful aggregate fn and an aggregation process conflict.

    Passing both `stateful_delta_aggregate_fn` and `aggregation_process` is
    expected to raise `DisjointArgumentError`.
    """
    weights_type = model_utils.weights_type_from_model(
        model_examples.LinearRegression)

    def _mean_next(state, value, weight=None):
        return state, intrinsics.federated_mean(value, weight)

    # Both arguments are built inside the `assertRaises` scope, in the same
    # order in which the original call evaluated them.
    with self.assertRaises(optimizer_utils.DisjointArgumentError):
        stateful_mean = computation_utils.StatefulAggregateFn(
            initialize_fn=lambda: (), next_fn=_mean_next)
        mean_process = optimizer_utils.build_stateless_mean(
            model_delta_type=weights_type.trainable)
        optimizer_utils.build_model_delta_optimizer_process(
            model_fn=model_examples.LinearRegression,
            model_to_client_delta_fn=DummyClientDeltaFn,
            server_optimizer_fn=tf.keras.optimizers.SGD,
            stateful_delta_aggregate_fn=stateful_mean,
            aggregation_process=mean_process)
def test_fails_stateful_broadcast_and_process(self):
    """Checks that a stateful broadcast fn and a broadcast process conflict.

    Passing both `stateful_model_broadcast_fn` and `broadcast_process` is
    expected to raise `DisjointArgumentError`.
    """
    # NOTE(review): the graph scope is assumed to cover only the derivation of
    # the weights type; the original indentation was lost -- confirm.
    with tf.Graph().as_default():
        weights = model_utils.ModelWeights.from_model(
            model_examples.LinearRegression())
        weights_type = tff.framework.type_from_tensors(weights)

    def _broadcast_next(state, weights):
        return state, tff.federated_broadcast(weights)

    with self.assertRaises(optimizer_utils.DisjointArgumentError):
        stateful_broadcast = tff.utils.StatefulBroadcastFn(
            initialize_fn=lambda: (), next_fn=_broadcast_next)
        stateless_process = optimizer_utils.build_stateless_broadcaster(
            model_weights_type=weights_type)
        optimizer_utils.build_model_delta_optimizer_process(
            model_fn=model_examples.LinearRegression,
            model_to_client_delta_fn=DummyClientDeltaFn,
            server_optimizer_fn=tf.keras.optimizers.SGD,
            stateful_model_broadcast_fn=stateful_broadcast,
            broadcast_process=stateless_process)
def test_iterative_process_with_encoding(self):
    """Runs one round with encoded aggregation and encoded broadcast.

    The first trainable entry of the broadcast state is expected to be 1
    after `initialize` and 2 after one call to `next`.
    """
    model_fn = model_examples.LinearRegression
    encoded_mean = encoding_utils.build_encoded_mean_from_model(
        model_fn, _test_encoder_fn('gather'))
    encoded_broadcast = encoding_utils.build_encoded_broadcast_from_model(
        model_fn, _test_encoder_fn('simple'))
    process = optimizer_utils.build_model_delta_optimizer_process(
        model_fn=model_fn,
        model_to_client_delta_fn=DummyClientDeltaFn,
        server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
        stateful_delta_aggregate_fn=encoded_mean,
        stateful_model_broadcast_fn=encoded_broadcast)

    examples = collections.OrderedDict([
        ('x', [[1.0, 2.0], [3.0, 4.0]]),
        ('y', [[5.0], [6.0]]),
    ])
    client_data = tf.data.Dataset.from_tensor_slices(examples).batch(2)
    # Three clients, all holding the same dataset.
    federated_data = [client_data] * 3

    state = process.initialize()
    self.assertEqual(state.model_broadcast_state.trainable[0][0], 1)

    state, _ = process.next(state, federated_data)
    self.assertEqual(state.model_broadcast_state.trainable[0][0], 2)
def test_construction_with_broadcast_process(self):
    """Checks that the broadcast process's state type appears in the server state.

    The `model_broadcast_state` type is compared, placed at the server, with
    the result type of the broadcast process's `initialize`, for the
    `initialize` result, the `next` parameter and the `next` result.
    """
    weights_type = model_utils.weights_type_from_model(
        model_examples.LinearRegression)
    broadcast_process = _build_test_measured_broadcast(weights_type)
    process = optimizer_utils.build_model_delta_optimizer_process(
        model_fn=model_examples.LinearRegression,
        model_to_client_delta_fn=DummyClientDeltaFn,
        server_optimizer_fn=tf.keras.optimizers.SGD,
        broadcast_process=broadcast_process)

    expected_state_type = broadcast_process.initialize.type_signature.result
    initialize_sig = process.initialize.type_signature
    next_sig = process.next.type_signature
    server_state_types = (
        initialize_sig.result.member,
        next_sig.parameter[0].member,
        next_sig.result[0].member,
    )
    for server_state_type in server_state_types:
        self.assertEqual(
            computation_types.FederatedType(
                server_state_type.model_broadcast_state, placements.SERVER),
            expected_state_type)
def test_orchestration_execute(self):
    """Runs one round and checks the model weights and reported outputs."""
    process = optimizer_utils.build_model_delta_optimizer_process(
        model_fn=model_examples.TrainableLinearRegression,
        model_to_client_delta_fn=DummyClientDeltaFn,
        server_optimizer_fn=lambda: gradient_descent.SGD(learning_rate=1.0))

    examples = {'x': [[1., 2.], [3., 4.]], 'y': [[5.], [6.]]}
    client_data = tf.data.Dataset.from_tensor_slices(examples).batch(2)
    # Three clients, all holding the same dataset.
    federated_data = [client_data] * 3

    state = process.initialize()
    initial_model = state.model
    self.assertSequenceAlmostEqual(initial_model.trainable.a,
                                   np.zeros([2, 1], np.float32))
    self.assertAlmostEqual(initial_model.trainable.b, 0.0)
    self.assertAlmostEqual(initial_model.non_trainable.c, 0.0)

    state, outputs = process.next(state, federated_data)
    updated_model = state.model
    self.assertSequenceAlmostEqual(updated_model.trainable.a,
                                   -np.ones([2, 1], np.float32))
    self.assertAlmostEqual(updated_model.trainable.b, -1.0)
    self.assertAlmostEqual(updated_model.non_trainable.c, 0.0)

    # Since all predictions are 0, loss is:
    #    0.5 * ((0-5)^2 + (0-6)^2) / 2 = 15.25
    self.assertAlmostEqual(outputs.loss, 15.25, places=4)
    # 3 clients * 2 examples per client = 6 examples.
    self.assertAlmostEqual(outputs.num_examples, 6.0, places=8)
def test_construction_with_aggregation_process(self, server_optimizer):
    """Checks that the aggregator's state and measurement types are propagated.

    Args:
      server_optimizer: A no-arg callable whose result is passed as
        `server_optimizer_fn`.
    """
    update_type = model_utils.weights_type_from_model(
        model_examples.LinearRegression).trainable
    aggregator_factory = TestMeasuredMeanFactory()
    process = optimizer_utils.build_model_delta_optimizer_process(
        model_fn=model_examples.LinearRegression,
        model_to_client_delta_fn=DummyClientDeltaFn,
        server_optimizer_fn=server_optimizer(),
        model_update_aggregation_factory=aggregator_factory)

    # Build a reference process from the same factory to obtain the expected
    # state and measurement types.
    reference_process = aggregator_factory.create(
        update_type, computation_types.TensorType(tf.float32))
    expected_state_type = reference_process.initialize.type_signature.result

    initialize_sig = process.initialize.type_signature
    next_sig = process.next.type_signature
    server_state_types = (
        initialize_sig.result.member,
        next_sig.parameter[0].member,
        next_sig.result[0].member,
    )
    for server_state_type in server_state_types:
        self.assertEqual(
            computation_types.FederatedType(
                server_state_type.delta_aggregate_state, placements.SERVER),
            expected_state_type)

    expected_metrics_type = (
        reference_process.next.type_signature.result.measurements)
    self.assertEqual(
        computation_types.FederatedType(
            next_sig.result[1].member.aggregation, placements.SERVER),
        expected_metrics_type)
def build_federated_averaging_process(
        model_fn,
        server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
        client_weight_fn=None):
    """Constructs the TFF computations for federated averaging.

    Args:
      model_fn: A no-arg function that returns a
        `tff.learning.TrainableModel`.
      server_optimizer_fn: A no-arg function that returns a `tf.Optimizer`
        whose `apply_gradients` method is used to apply client updates to the
        server model. The default builds a `tf.keras.optimizers.SGD` with a
        learning rate of 1.0, which simply adds the average client delta to
        the server's model.
      client_weight_fn: Optional function that takes the output of
        `model.report_local_outputs` and returns a tensor giving the weight
        in the federated average of model deltas. When omitted, the default
        is the total number of examples processed on device.

    Returns:
      A `tff.utils.IterativeProcess`.
    """

    def _make_client_fed_avg(model_fn):
        # Instantiate the model here so each client delta fn gets its own.
        client_model = model_fn()
        return ClientFedAvg(client_model, client_weight_fn)

    return optimizer_utils.build_model_delta_optimizer_process(
        model_fn, _make_client_fed_avg, server_optimizer_fn)
def test_construction_with_aggregation_process(self):
    """Checks that an aggregation process's state and measurement types are propagated."""
    update_type = model_utils.weights_type_from_model(
        model_examples.LinearRegression).trainable
    aggregation_process = _build_test_measured_mean(update_type)
    process = optimizer_utils.build_model_delta_optimizer_process(
        model_fn=model_examples.LinearRegression,
        model_to_client_delta_fn=DummyClientDeltaFn,
        server_optimizer_fn=tf.keras.optimizers.SGD,
        aggregation_process=aggregation_process)

    expected_state_type = aggregation_process.initialize.type_signature.result
    initialize_sig = process.initialize.type_signature
    next_sig = process.next.type_signature
    server_state_types = (
        initialize_sig.result.member,
        next_sig.parameter[0].member,
        next_sig.result[0].member,
    )
    for server_state_type in server_state_types:
        self.assertEqual(
            computation_types.FederatedType(
                server_state_type.delta_aggregate_state, placements.SERVER),
            expected_state_type)

    expected_metrics_type = (
        aggregation_process.next.type_signature.result.measurements)
    self.assertEqual(
        computation_types.FederatedType(
            next_sig.result[1].member.aggregation, placements.SERVER),
        expected_metrics_type)
def test_orchestration_execute_measured_process(self, server_optimizer):
    """Runs one round with a measured broadcast process and mean factory.

    Args:
      server_optimizer: A callable taking a learning rate; its result is
        passed as `server_optimizer_fn`.
    """
    weights_type = model_utils.weights_type_from_model(
        model_examples.LinearRegression)
    learning_rate = 1.0
    server_optimizer_fn = server_optimizer(learning_rate)
    # The test branches on whether the optimizer argument is callable.
    optimizer_is_callable = callable(server_optimizer_fn)
    process = optimizer_utils.build_model_delta_optimizer_process(
        model_fn=model_examples.LinearRegression,
        model_to_client_delta_fn=DummyClientDeltaFn,
        server_optimizer_fn=server_optimizer_fn,
        broadcast_process=_build_test_measured_broadcast(weights_type),
        model_update_aggregation_factory=TestMeasuredMeanFactory())

    examples = collections.OrderedDict([
        ('x', [[1.0, 2.0], [3.0, 4.0]]),
        ('y', [[5.0], [6.0]]),
    ])
    client_dataset = tf.data.Dataset.from_tensor_slices(examples).batch(2)
    federated_dataset = [client_dataset] * 3

    state = process.initialize()
    if optimizer_is_callable:
        # Keras SGD keeps track of a single scalar for the number of
        # iterations.
        self.assertAllEqual(state.optimizer_state, [0])
    else:
        # TFF SGD stores learning rate in state.
        expected_optimizer_state = collections.OrderedDict([
            (optimizer.LEARNING_RATE_KEY, learning_rate),
        ])
        self.assertAllClose(state.optimizer_state, expected_optimizer_state)
    self.assertAllClose(list(state.model.trainable), [np.zeros((2, 1)), 0.0])
    self.assertAllClose(list(state.model.non_trainable), [0.0])
    self.assertEqual(state.delta_aggregate_state, 0)
    self.assertEqual(state.model_broadcast_state, 0)

    state, outputs = process.next(state, federated_dataset)
    # `DummyClientDeltaFn` always sends fake model weights deltas (negative
    # ones) back. Because the initial model weights are all zeros, the
    # updated model weights will be all negative ones.
    self.assertAllClose(list(state.model.trainable), [-np.ones((2, 1)), -1.0])
    self.assertAllClose(list(state.model.non_trainable), [0.0])
    if optimizer_is_callable:
        self.assertAllEqual(state.optimizer_state, [1])
    self.assertEqual(state.delta_aggregate_state, 1)
    self.assertEqual(state.model_broadcast_state, 1)

    # The average mean squared loss is computed at the initial model weights
    # (i.e., at zero weights): 0.5*(25+36)/2 = 15.25.
    train_metrics = collections.OrderedDict([
        ('loss', 15.25),
        ('num_examples', 6),
    ])
    # `_build_test_measured_broadcast` builds a broadcast process whose
    # `measurements` is 3.0 + norm of initial model weights (which is zero).
    expected_outputs = collections.OrderedDict([
        ('broadcast', 3.0),
        ('aggregation', collections.OrderedDict([('num_clients', 3)])),
        ('train', train_metrics),
    ])
    self.assertAllEqual(expected_outputs, outputs)
def build_federated_averaging_process(
        model_fn: Callable[[], model_lib.Model],
        client_optimizer_fn: Callable[[], tf.keras.optimizers.Optimizer],
        server_optimizer_fn: Callable[
            [], tf.keras.optimizers.Optimizer] = DEFAULT_SERVER_OPTIMIZER_FN,
        client_weight_fn: Callable[[Any], tf.Tensor] = None,
        stateful_delta_aggregate_fn=None,
        stateful_model_broadcast_fn=None) -> tff.utils.IterativeProcess:
    """Constructs the TFF computations for federated averaging.

    Args:
      model_fn: A no-arg function that returns a `tff.learning.Model`.
      client_optimizer_fn: A no-arg callable that returns a
        `tf.keras.Optimizer`.
      server_optimizer_fn: A no-arg callable that returns a
        `tf.keras.Optimizer` whose `apply_gradients` method is used to apply
        client updates to the server model. The default creates a
        `tf.keras.optimizers.SGD` with a learning rate of 1.0, which simply
        adds the average client delta to the server's model.
      client_weight_fn: Optional function that takes the output of
        `model.report_local_outputs` and returns a tensor giving the weight
        in the federated average of model deltas. When omitted, the default
        is the total number of examples processed on device.
      stateful_delta_aggregate_fn: A `tff.utils.StatefulAggregateFn` whose
        `next_fn` performs a federated aggregation and updates state. That
        is, it has TFF type `(state@SERVER, value@CLIENTS, weights@CLIENTS)
        -> (state@SERVER, aggregate@SERVER)`, where the `value` type is
        `tff.learning.framework.ModelWeights.trainable` corresponding to the
        object returned by `model_fn`. By default performs arithmetic mean
        aggregation, weighted by `client_weight_fn`.
      stateful_model_broadcast_fn: A `tff.utils.StatefulBroadcastFn` whose
        `next_fn` performs a federated broadcast and updates state. That is,
        it has TFF type `(state@SERVER, value@SERVER) -> (state@SERVER,
        value@CLIENTS)`, where the `value` type is
        `tff.learning.framework.ModelWeights` corresponding to the object
        returned by `model_fn`. By default performs identity broadcast.

    Returns:
      A `tff.utils.IterativeProcess`.
    """
    # Validate a caller-supplied aggregator, otherwise fall back to the
    # stateless mean.
    if stateful_delta_aggregate_fn is not None:
        py_typecheck.check_type(stateful_delta_aggregate_fn,
                                tff.utils.StatefulAggregateFn)
    else:
        stateful_delta_aggregate_fn = optimizer_utils.build_stateless_mean()

    # Same for the broadcaster.
    if stateful_model_broadcast_fn is not None:
        py_typecheck.check_type(stateful_model_broadcast_fn,
                                tff.utils.StatefulBroadcastFn)
    else:
        stateful_model_broadcast_fn = (
            optimizer_utils.build_stateless_broadcaster())

    def _make_client_fed_avg(model_fn):
        # A fresh model and a fresh client optimizer are created per call.
        client_model = model_fn()
        client_optimizer = client_optimizer_fn()
        return _ClientFedAvg(client_model, client_optimizer, client_weight_fn)

    return optimizer_utils.build_model_delta_optimizer_process(
        model_fn, _make_client_fed_avg, server_optimizer_fn,
        stateful_delta_aggregate_fn, stateful_model_broadcast_fn)
def test_construction(self, weighted):
    """Checks the `initialize` and `next` type signatures of the built process.

    Args:
      weighted: If true a `MeanFactory` is used as the aggregator, otherwise
        a `SumFactory`; the expected aggregator state and metrics differ
        accordingly.
    """
    if weighted:
        aggregation_factory = mean.MeanFactory()
        aggregate_state = collections.OrderedDict([
            ('value_sum_process', ()),
            ('weight_sum_process', ()),
        ])
        aggregate_metrics = collections.OrderedDict([
            ('mean_value', ()),
            ('mean_weight', ()),
        ])
    else:
        aggregation_factory = sum_factory.SumFactory()
        aggregate_state = ()
        aggregate_metrics = ()

    process = optimizer_utils.build_model_delta_optimizer_process(
        model_fn=model_examples.LinearRegression,
        model_to_client_delta_fn=DummyClientDeltaFn,
        server_optimizer_fn=tf.keras.optimizers.SGD,
        model_update_aggregation_factory=aggregation_factory)

    model_weights_type = model_utils.ModelWeights(
        trainable=[
            computation_types.TensorType(tf.float32, [2, 1]),
            computation_types.TensorType(tf.float32),
        ],
        non_trainable=[computation_types.TensorType(tf.float32)])
    server_state_type = computation_types.FederatedType(
        optimizer_utils.ServerState(
            model=model_weights_type,
            optimizer_state=[tf.int64],
            delta_aggregate_state=aggregate_state,
            model_broadcast_state=()),
        placements.SERVER)
    expected_initialize_type = computation_types.FunctionType(
        parameter=None, result=server_state_type)
    self.assert_types_equivalent(expected_initialize_type,
                                 process.initialize.type_signature)

    batch_type = collections.OrderedDict([
        ('x', computation_types.TensorType(tf.float32, [None, 2])),
        ('y', computation_types.TensorType(tf.float32, [None, 1])),
    ])
    dataset_type = computation_types.FederatedType(
        computation_types.SequenceType(batch_type), placements.CLIENTS)
    train_metrics_type = collections.OrderedDict([
        ('loss', computation_types.TensorType(tf.float32)),
        ('num_examples', computation_types.TensorType(tf.int32)),
    ])
    stat_metrics_type = collections.OrderedDict([
        ('num_examples', computation_types.TensorType(tf.float32)),
    ])
    metrics_type = computation_types.FederatedType(
        collections.OrderedDict([
            ('broadcast', ()),
            ('aggregation', aggregate_metrics),
            ('train', train_metrics_type),
            ('stat', stat_metrics_type),
        ]),
        placements.SERVER)
    expected_next_type = computation_types.FunctionType(
        parameter=collections.OrderedDict([
            ('server_state', server_state_type),
            ('federated_dataset', dataset_type),
        ]),
        result=(server_state_type, metrics_type))
    self.assert_types_equivalent(expected_next_type,
                                 process.next.type_signature)
def test_iterative_process_with_encoding(self):
    """Builds a process with encoded aggregation and broadcast, then verifies it."""
    model_fn = model_examples.LinearRegression
    encoded_mean = encoding_utils.build_encoded_mean_from_model(
        model_fn, _test_encoder_fn('gather'))
    encoded_broadcast = encoding_utils.build_encoded_broadcast_from_model(
        model_fn, _test_encoder_fn('simple'))
    process = optimizer_utils.build_model_delta_optimizer_process(
        model_fn=model_fn,
        model_to_client_delta_fn=DummyClientDeltaFn,
        server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
        stateful_delta_aggregate_fn=encoded_mean,
        stateful_model_broadcast_fn=encoded_broadcast)
    # The checks themselves live in the shared `_verify_iterative_process`.
    self._verify_iterative_process(process)
def test_construction_with_adam_optimizer(self):
    """Checks the optimizer-state length when Adam is the server optimizer."""
    process = optimizer_utils.build_model_delta_optimizer_process(
        model_fn=model_examples.LinearRegression,
        model_to_client_delta_fn=DummyClientDeltaFn,
        server_optimizer_fn=tf.keras.optimizers.Adam)

    # The optimizer_state is expected to hold 5 variables: a scalar for the
    # number of iterations, plus two copies of the kernel and bias in the
    # model.
    expected_len = 5
    initialize_sig = process.initialize.type_signature
    next_sig = process.next.type_signature
    server_state_types = (
        initialize_sig.result.member,
        next_sig.parameter[0].member,
        next_sig.result[0].member,
    )
    for server_state_type in server_state_types:
        self.assertLen(server_state_type.optimizer_state, expected_len)
def test_orchestration_execute_measured_process(self, server_optimizer):
    """Runs one round with a measured broadcast process and mean factory.

    Args:
      server_optimizer: A callable taking a learning rate; its result is
        passed as `server_optimizer_fn`.
    """
    weights_type = model_utils.weights_type_from_model(
        model_examples.LinearRegression)
    learning_rate = 1.0
    server_optimizer_fn = server_optimizer(learning_rate)
    # The test branches on whether the optimizer argument is callable.
    optimizer_is_callable = callable(server_optimizer_fn)
    process = optimizer_utils.build_model_delta_optimizer_process(
        model_fn=model_examples.LinearRegression,
        model_to_client_delta_fn=DummyClientDeltaFn,
        server_optimizer_fn=server_optimizer_fn,
        broadcast_process=_build_test_measured_broadcast(weights_type),
        model_update_aggregation_factory=TestMeasuredMeanFactory())

    examples = collections.OrderedDict([
        ('x', [[1.0, 2.0], [3.0, 4.0]]),
        ('y', [[5.0], [6.0]]),
    ])
    client_data = tf.data.Dataset.from_tensor_slices(examples).batch(2)
    federated_data = [client_data] * 3

    state = process.initialize()
    if optimizer_is_callable:
        # Keras SGD keeps track of a single scalar for the number of
        # iterations.
        self.assertAllEqual(state.optimizer_state, [0])
    else:
        # TFF SGD has an empty state.
        self.assertEmpty(state.optimizer_state)
    self.assertAllClose(list(state.model.trainable), [np.zeros((2, 1)), 0.0])
    self.assertAllClose(list(state.model.non_trainable), [0.0])
    self.assertEqual(state.delta_aggregate_state, 0)
    self.assertEqual(state.model_broadcast_state, 0)

    state, outputs = process.next(state, federated_data)
    expected_trainable = [-np.ones((2, 1)), -1.0 * learning_rate]
    self.assertAllClose(list(state.model.trainable), expected_trainable)
    self.assertAllClose(list(state.model.non_trainable), [0.0])
    if optimizer_is_callable:
        self.assertAllEqual(state.optimizer_state, [1])
    self.assertEqual(state.delta_aggregate_state, 1)
    self.assertEqual(state.model_broadcast_state, 1)

    # `train` is deliberately a plain dict, as in the original expectation.
    train_metrics = {'loss': 15.25, 'num_examples': 6}
    expected_outputs = collections.OrderedDict([
        ('broadcast', 3.0),
        ('aggregation', collections.OrderedDict([('num_clients', 3)])),
        ('train', train_metrics),
        ('stat', collections.OrderedDict([('num_examples', 3.0)])),
    ])
    self.assertAllEqual(expected_outputs, outputs)
def test_orchestration_type_signature(self):
    """Checks the `initialize` and `next` type signatures of the built process."""
    process = optimizer_utils.build_model_delta_optimizer_process(
        model_fn=model_examples.TrainableLinearRegression,
        model_to_client_delta_fn=DummyClientDeltaFn,
        server_optimizer_fn=lambda: gradient_descent.SGD(learning_rate=1.0))

    trainable_type = collections.OrderedDict([
        ('a', tff.TensorType(tf.float32, [2, 1])),
        ('b', tf.float32),
    ])
    non_trainable_type = collections.OrderedDict([('c', tf.float32)])
    model_weights_type = model_utils.ModelWeights(trainable_type,
                                                  non_trainable_type)

    # Only the model weights are pinned; the remaining `ServerState` fields
    # are matched with `AnyType` placeholders. The optimizer state comes from
    # TensorFlow, and TFF doesn't care what the actual value is.
    server_state_type = tff.FederatedType(
        optimizer_utils.ServerState(model_weights_type, test.AnyType(),
                                    test.AnyType(), test.AnyType()),
        placement=tff.SERVER,
        all_equal=True)

    input_spec = model_examples.TrainableLinearRegression().input_spec
    dataset_type = tff.FederatedType(
        tff.SequenceType(input_spec), tff.CLIENTS, all_equal=False)

    model_output_type = tff.FederatedType(
        collections.OrderedDict([
            ('loss', tff.TensorType(tf.float32, [])),
            ('num_examples', tff.TensorType(tf.int32, [])),
        ]),
        tff.SERVER,
        all_equal=True)

    # `initialize` is expected to be a function of no arguments to a
    # ServerState.
    expected_initialize_type = tff.FunctionType(
        parameter=None, result=server_state_type)
    self.assertEqual(expected_initialize_type,
                     process.initialize.type_signature)

    # `next` is expected to be a function of (ServerState, Datasets) to
    # (ServerState, model outputs).
    expected_next_type = tff.FunctionType(
        parameter=[server_state_type, dataset_type],
        result=(server_state_type, model_output_type))
    self.assertEqual(expected_next_type, process.next.type_signature)
def test_orchestration_execute_measured_process(self):
    """Runs one round with measured broadcast and aggregation processes."""
    # NOTE(review): the graph scope is assumed to cover only the derivation of
    # the weights type; the original indentation was lost -- confirm.
    with tf.Graph().as_default():
        weights = model_utils.ModelWeights.from_model(
            model_examples.LinearRegression())
        weights_type = tff.framework.type_from_tensors(weights)

    learning_rate = 1.0
    process = optimizer_utils.build_model_delta_optimizer_process(
        model_fn=model_examples.LinearRegression,
        model_to_client_delta_fn=DummyClientDeltaFn,
        server_optimizer_fn=functools.partial(
            tf.keras.optimizers.SGD, learning_rate=learning_rate),
        broadcast_process=_build_test_measured_broadcast(weights_type),
        aggregation_process=_build_test_measured_mean(weights_type.trainable))

    examples = collections.OrderedDict([
        ('x', [[1.0, 2.0], [3.0, 4.0]]),
        ('y', [[5.0], [6.0]]),
    ])
    client_data = tf.data.Dataset.from_tensor_slices(examples).batch(2)
    federated_data = [client_data] * 3

    state = process.initialize()
    # SGD keeps track of a single scalar for the number of iterations.
    self.assertAllEqual(state.optimizer_state, [0])
    self.assertAllClose(list(state.model.trainable), [np.zeros((2, 1)), 0.0])
    self.assertAllClose(list(state.model.non_trainable), [0.0])
    self.assertEqual(state.delta_aggregate_state, 0)
    self.assertEqual(state.model_broadcast_state, 0)

    state, outputs = process.next(state, federated_data)
    expected_trainable = [-np.ones((2, 1)), -1.0 * learning_rate]
    self.assertAllClose(list(state.model.trainable), expected_trainable)
    self.assertAllClose(list(state.model.non_trainable), [0.0])
    self.assertAllEqual(state.optimizer_state, [1])
    self.assertEqual(state.delta_aggregate_state, 1)
    self.assertEqual(state.model_broadcast_state, 1)

    expected_container = collections.OrderedDict([
        ('broadcast', 3.0),
        ('aggregation', collections.OrderedDict([('num_clients', 3)])),
        ('train', collections.OrderedDict([
            ('loss', 15.25),
            ('num_examples', 6),
        ])),
    ])
    expected_outputs = anonymous_tuple.from_container(
        expected_container, recursive=True)
    self.assertEqual(expected_outputs, outputs)
def test_orchestration_execute_stateful_fn(self):
    """Runs one round with state-incrementing aggregate and broadcast fns."""
    learning_rate = 1.0
    server_optimizer_fn = functools.partial(
        tf.keras.optimizers.SGD, learning_rate=learning_rate)
    process = optimizer_utils.build_model_delta_optimizer_process(
        model_fn=model_examples.LinearRegression,
        model_to_client_delta_fn=DummyClientDeltaFn,
        server_optimizer_fn=server_optimizer_fn,
        # A federated_mean that maintains an int32 state equal to the number
        # of times the federated_mean has been executed, allowing us to test
        # that a stateful aggregator's state is properly updated.
        stateful_delta_aggregate_fn=state_incrementing_mean,
        # Similarly, a broadcast with state that increments:
        stateful_model_broadcast_fn=state_incrementing_broadcaster)

    examples = collections.OrderedDict([
        ('x', [[1.0, 2.0], [3.0, 4.0]]),
        ('y', [[5.0], [6.0]]),
    ])
    client_data = tf.data.Dataset.from_tensor_slices(examples).batch(2)
    federated_data = [client_data] * 3

    state = process.initialize()
    # SGD keeps track of a single scalar for the number of iterations.
    self.assertAllEqual(state.optimizer_state, [0])
    self.assertAllClose(list(state.model.trainable), [np.zeros((2, 1)), 0.0])
    self.assertAllClose(list(state.model.non_trainable), [0.0])
    self.assertEqual(state.delta_aggregate_state, 0)
    self.assertEqual(state.model_broadcast_state, 0)

    state, outputs = process.next(state, federated_data)
    expected_trainable = [-np.ones((2, 1)), -1.0 * learning_rate]
    self.assertAllClose(list(state.model.trainable), expected_trainable)
    self.assertAllClose(list(state.model.non_trainable), [0.0])
    self.assertAllEqual(state.optimizer_state, [1])
    self.assertEqual(state.delta_aggregate_state, 1)
    self.assertEqual(state.model_broadcast_state, 1)

    # The StatefulFn variants output no metrics, hence the empty `broadcast`
    # and `aggregation` entries.
    expected_container = collections.OrderedDict([
        ('broadcast', ()),
        ('aggregation', ()),
        ('train', collections.OrderedDict([
            ('loss', 15.25),
            ('num_examples', 6),
        ])),
    ])
    expected_outputs = anonymous_tuple.from_container(
        expected_container, recursive=True)
    self.assertEqual(expected_outputs, outputs)
def test_construction(self):
    """Checks the string form of the `initialize` and `next` type signatures."""
    process = optimizer_utils.build_model_delta_optimizer_process(
        model_fn=model_examples.LinearRegression,
        model_to_client_delta_fn=DummyClientDeltaFn,
        server_optimizer_fn=tf.keras.optimizers.SGD)

    model_weights_type = model_utils.ModelWeights(
        trainable=[
            computation_types.TensorType(tf.float32, [2, 1]),
            computation_types.TensorType(tf.float32),
        ],
        non_trainable=[computation_types.TensorType(tf.float32)])
    server_state_type = computation_types.FederatedType(
        optimizer_utils.ServerState(
            model=model_weights_type,
            optimizer_state=[tf.int64],
            delta_aggregate_state=(),
            model_broadcast_state=()),
        placements.SERVER)

    # Signatures are compared through their string representations.
    expected_initialize_type = computation_types.FunctionType(
        parameter=None, result=server_state_type)
    self.assertEqual(
        str(process.initialize.type_signature),
        str(expected_initialize_type))

    batch_type = collections.OrderedDict([
        ('x', computation_types.TensorType(tf.float32, [None, 2])),
        ('y', computation_types.TensorType(tf.float32, [None, 1])),
    ])
    dataset_type = computation_types.FederatedType(
        computation_types.SequenceType(batch_type), placements.CLIENTS)
    train_metrics_type = collections.OrderedDict([
        ('loss', computation_types.TensorType(tf.float32)),
        ('num_examples', computation_types.TensorType(tf.int32)),
    ])
    metrics_type = computation_types.FederatedType(
        collections.OrderedDict([
            ('broadcast', ()),
            ('aggregation', ()),
            ('train', train_metrics_type),
        ]),
        placements.SERVER)
    expected_next_type = computation_types.FunctionType(
        parameter=collections.OrderedDict([
            ('server_state', server_state_type),
            ('federated_dataset', dataset_type),
        ]),
        result=(server_state_type, metrics_type))
    self.assertEqual(
        str(process.next.type_signature), str(expected_next_type))
def test_construction_with_tff_sgdm_optimizer(self, momentum, state_len):
    """Checks the optimizer-state length for the TFF SGDM optimizer.

    Args:
      momentum: Momentum value passed to `sgdm.build_sgdm`.
      state_len: Expected length of `optimizer_state`.
    """
    server_optimizer = sgdm.build_sgdm(learning_rate=0.1, momentum=momentum)
    process = optimizer_utils.build_model_delta_optimizer_process(
        model_fn=model_examples.LinearRegression,
        model_to_client_delta_fn=DummyClientDeltaFn,
        server_optimizer_fn=server_optimizer)

    # The optimizer_state is expected to be empty for SGD without momentum,
    # and to hold two tensors with momentum: one each for the kernel and bias
    # of the model.
    initialize_sig = process.initialize.type_signature
    next_sig = process.next.type_signature
    server_state_types = (
        initialize_sig.result.member,
        next_sig.parameter[0].member,
        next_sig.result[0].member,
    )
    for server_state_type in server_state_types:
        self.assertLen(server_state_type.optimizer_state, state_len)
def test_contruction_with_broadcast_state(self):
    """Checks that the broadcast state type is an int32 tensor throughout."""
    process = optimizer_utils.build_model_delta_optimizer_process(
        model_fn=model_examples.LinearRegression,
        model_to_client_delta_fn=DummyClientDeltaFn,
        server_optimizer_fn=tf.keras.optimizers.SGD,
        stateful_model_broadcast_fn=state_incrementing_broadcaster)

    expected_state_type = tff.TensorType(tf.int32)
    initialize_sig = process.initialize.type_signature
    next_sig = process.next.type_signature
    server_state_types = (
        initialize_sig.result.member,
        next_sig.parameter[0].member,
        next_sig.result[0].member,
    )
    for server_state_type in server_state_types:
        self.assertEqual(server_state_type.model_broadcast_state,
                         expected_state_type)
def test_orchestration_execute_sgd(self):
    """Runs one round with Keras SGD and state-incrementing aggregate/broadcast."""
    learning_rate = 1.0
    server_optimizer_fn = functools.partial(
        tf.keras.optimizers.SGD, learning_rate=learning_rate)
    process = optimizer_utils.build_model_delta_optimizer_process(
        model_fn=model_examples.LinearRegression,
        model_to_client_delta_fn=DummyClientDeltaFn,
        server_optimizer_fn=server_optimizer_fn,
        # A federated_mean that maintains an int32 state equal to the number
        # of times the federated_mean has been executed, allowing us to test
        # that a stateful aggregator's state is properly updated.
        stateful_delta_aggregate_fn=state_incrementing_mean,
        # Similarly, a broadcast with state that increments:
        stateful_model_broadcast_fn=state_incrementing_broadcaster)

    examples = collections.OrderedDict([
        ('x', [[1.0, 2.0], [3.0, 4.0]]),
        ('y', [[5.0], [6.0]]),
    ])
    client_data = tf.data.Dataset.from_tensor_slices(examples).batch(2)
    federated_data = [client_data] * 3

    state = process.initialize()
    # SGD keeps track of a single scalar for the number of iterations.
    self.assertAllEqual(state.optimizer_state, [0])
    self.assertAllClose(list(state.model.trainable), [np.zeros((2, 1)), 0.0])
    self.assertAllClose(list(state.model.non_trainable), [0.0])
    self.assertEqual(state.delta_aggregate_state, 0)
    self.assertEqual(state.model_broadcast_state, 0)

    state, outputs = process.next(state, federated_data)
    expected_trainable = [-np.ones((2, 1)), -1.0 * learning_rate]
    self.assertAllClose(list(state.model.trainable), expected_trainable)
    self.assertAllClose(list(state.model.non_trainable), [0.0])
    self.assertAllEqual(state.optimizer_state, [1])
    self.assertEqual(state.delta_aggregate_state, 1)
    self.assertEqual(state.model_broadcast_state, 1)

    # Since all predictions are 0, loss is:
    #    0.5 * ((0-5)^2 + (0-6)^2) / 2 = 15.25
    self.assertAlmostEqual(outputs.loss, 15.25, places=4)
    # 3 clients * 2 examples per client = 6 examples.
    self.assertAlmostEqual(outputs.num_examples, 6.0, places=8)
def test_orchestration_execute(self):
    """Runs one round on `TrainableLinearRegression` and checks the outcome.

    Uses a stateful aggregator and a stateful broadcaster together, and
    verifies the named model weights, both state counters and the reported
    metrics around a single call to `next`.
    """

    def sgd_fn():
        return tf.keras.optimizers.SGD(learning_rate=1.0)

    process = optimizer_utils.build_model_delta_optimizer_process(
        model_fn=model_examples.TrainableLinearRegression,
        model_to_client_delta_fn=DummyClientDeltaFn,
        server_optimizer_fn=sgd_fn,
        # This mean keeps an int32 state counting how many times it has been
        # executed, which lets the test confirm that a stateful aggregator's
        # state is properly updated.
        stateful_delta_aggregate_fn=state_incrementing_mean,
        # The broadcaster likewise carries a state that increments.
        stateful_model_broadcast_fn=state_incrementing_broadcaster)

    examples = collections.OrderedDict([
        ('x', [[1.0, 2.0], [3.0, 4.0]]),
        ('y', [[5.0], [6.0]]),
    ])
    client_dataset = tf.data.Dataset.from_tensor_slices(examples).batch(2)
    # Three clients, all holding the same two-example batch.
    federated_ds = [client_dataset] * 3

    state = process.initialize()
    self.assertSequenceAlmostEqual(
        state.model.trainable.a, np.zeros([2, 1], np.float32))
    self.assertAlmostEqual(state.model.trainable.b, 0.0)
    self.assertAlmostEqual(state.model.non_trainable.c, 0.0)
    self.assertEqual(state.delta_aggregate_state, 0)
    self.assertEqual(state.model_broadcast_state, 0)

    state, outputs = process.next(state, federated_ds)
    self.assertSequenceAlmostEqual(
        state.model.trainable.a, -np.ones([2, 1], np.float32))
    self.assertAlmostEqual(state.model.trainable.b, -1.0)
    self.assertAlmostEqual(state.model.non_trainable.c, 0.0)
    self.assertEqual(state.delta_aggregate_state, 1)
    self.assertEqual(state.model_broadcast_state, 1)

    # All predictions are 0, so the loss is:
    #    0.5 * ((0-5)^2 + (0-6)^2) / 2 = 15.25
    self.assertAlmostEqual(outputs.loss, 15.25, places=4)
    # 3 clients * 2 examples per client = 6 examples.
    self.assertAlmostEqual(outputs.num_examples, 6.0, places=8)
def test_iterative_process_with_encoding(self):
    """Checks that an encoded broadcast's state changes after one round.

    The first element of `model_broadcast_state.trainable.a` is expected to
    be 1 after `initialize` and 2 after a single call to `next`.
    """
    model_fn = model_examples.TrainableLinearRegression

    def sgd_fn():
        return tf.keras.optimizers.SGD(learning_rate=1.0)

    encoded_broadcast = encoding_utils.build_encoded_broadcast_from_model(
        model_fn, _test_encoder_fn())
    process = optimizer_utils.build_model_delta_optimizer_process(
        model_fn=model_fn,
        model_to_client_delta_fn=DummyClientDeltaFn,
        server_optimizer_fn=sgd_fn,
        stateful_model_broadcast_fn=encoded_broadcast)

    examples = {
        'x': [[1., 2.], [3., 4.]],
        'y': [[5.], [6.]]
    }
    client_dataset = tf.data.Dataset.from_tensor_slices(examples).batch(2)
    # Three clients, all holding the same two-example batch.
    federated_ds = [client_dataset] * 3

    state = process.initialize()
    self.assertEqual(state.model_broadcast_state.trainable.a[0], 1)

    state, _ = process.next(state, federated_ds)
    self.assertEqual(state.model_broadcast_state.trainable.a[0], 2)
def build_federated_process_for_test(model_fn, num_passes=5, tolerance=1e-6):
    """Build a test FedAvg process with a dummy client computation.

    Analogue of `build_federated_averaging_process`, but with client_fed_avg
    replaced by the dummy mean computation defined above.

    The server optimizer is Keras SGD with a learning rate of 1.0, the client
    deltas are combined by `rfa.build_stateless_robust_aggregation`, and the
    model is broadcast with `optimizer_utils.build_stateless_broadcaster`.

    Args:
      model_fn: callable that returns a `tff.learning.Model`.
      num_passes: integer number of communication rounds in the smoothed
        Weiszfeld algorithm (min. 1). Default 5.
      tolerance: float smoothing parameter of smoothed Weiszfeld algorithm.
        Default 1e-6.

    Returns:
      A `tff.utils.IterativeProcess`.
    """
    server_optimizer_fn = lambda: tf.keras.optimizers.SGD(learning_rate=1.0)

    # Each client gets a fresh model from `model_fn`; no client weighting
    # function is supplied.
    def client_fed_avg(model_fn):
        return DummyClientComputation(model_fn(), client_weight_fn=None)

    # Build robust aggregation function
    # NOTE(review): only the type inference is placed inside the scratch
    # graph here; confirm this scope against the original layout.
    with tf.Graph().as_default():
        # workaround since keras automatically appends "_n" to the nth call of
        # `model_fn`
        model_type = tff.framework.type_from_tensors(
            model_fn().weights.trainable)

    # The aggregator is typed on the trainable weights only.
    stateful_delta_aggregate_fn = rfa.build_stateless_robust_aggregation(
        model_type, num_communication_passes=num_passes, tolerance=tolerance)

    stateful_model_broadcast_fn = optimizer_utils.build_stateless_broadcaster(
    )

    return optimizer_utils.build_model_delta_optimizer_process(
        model_fn, client_fed_avg, server_optimizer_fn,
        stateful_delta_aggregate_fn, stateful_model_broadcast_fn)
def build_federated_sgd_process(
    model_fn: Callable[[], model_lib.Model],
    server_optimizer_fn: Callable[
        [], tf.keras.optimizers.Optimizer] = DEFAULT_SERVER_OPTIMIZER_FN,
    client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
    *,
    broadcast_process: Optional[measured_process.MeasuredProcess] = None,
    aggregation_process: Optional[measured_process.MeasuredProcess] = None,
    model_update_aggregation_factory: Optional[
        factory.AggregationProcessFactory] = None,
    use_experimental_simulation_loop: bool = False,
) -> iterative_process.IterativeProcess:
    """Builds the TFF computations for optimization using federated SGD.

    This function creates a `tff.templates.IterativeProcess` that performs
    federated SGD on client models. The iterative process has the following
    methods:

    *   `initialize`: A `tff.Computation` with the functional type signature
        `( -> S@SERVER)`, where `S` is a
        `tff.learning.framework.ServerState` representing the initial state
        of the server.
    *   `next`: A `tff.Computation` with the functional type signature
        `(<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>)` where `S` is a
        `tff.learning.framework.ServerState` whose type matches that of the
        output of `initialize`, and `{B*}@CLIENTS` represents the client
        datasets, where `B` is the type of a single batch. This computation
        returns a `tff.learning.framework.ServerState` representing the
        updated server state and metrics that are the result of
        `tff.learning.Model.federated_output_computation` during client
        training and any other metrics from broadcast and aggregation
        processes.

    Each time the `next` method is called, the server model is broadcast to
    each client using a broadcast function. Each client sums the gradients at
    each batch in the client's local dataset. These gradient sums are then
    aggregated at the server using an aggregation function. The aggregate
    gradients are applied at the server by using the
    `tf.keras.optimizers.Optimizer.apply_gradients` method of the server
    optimizer.

    This implements the original FedSGD algorithm in [McMahan et al.,
    2017](https://arxiv.org/abs/1602.05629).

    Note: the default server optimizer function is `tf.keras.optimizers.SGD`
    with a learning rate of 0.1. More sophisticated federated SGD procedures
    may use different learning rates or server optimizers.

    Args:
      model_fn: A no-arg function that returns a `tff.learning.Model`. This
        method must *not* capture TensorFlow tensors or variables and use
        them. The model must be constructed entirely from scratch on each
        invocation; returning the same pre-constructed model each call will
        result in an error.
      server_optimizer_fn: A no-arg function that returns a
        `tf.keras.optimizers.Optimizer`. The `apply_gradients` method of this
        optimizer is used to apply client updates to the server model.
      client_weight_fn: Optional function that takes the output of
        `model.report_local_outputs` and returns a tensor that provides the
        weight in the federated average of the aggregated gradients. If not
        provided, the default is the total number of examples processed on
        device.
      broadcast_process: a `tff.templates.MeasuredProcess` that broadcasts
        the model weights on the server to the clients. It must support the
        signature `(input_values@SERVER -> output_values@CLIENTS)`.
      aggregation_process: a `tff.templates.MeasuredProcess` that aggregates
        the model updates on the clients back to the server. It must support
        the signature `({input_values}@CLIENTS -> output_values@SERVER)`.
        Must be `None` if `model_update_aggregation_factory` is not `None`.
      model_update_aggregation_factory: An optional
        `tff.aggregators.AggregationProcessFactory` that constructs
        `tff.templates.AggregationProcess` for aggregating the client model
        updates on the server. If `None`, uses a default constructed
        `tff.aggregators.MeanFactory`, creating a stateless mean aggregation.
        Must be `None` if `aggregation_process` is not `None`.
      use_experimental_simulation_loop: Controls the reduce loop function for
        input dataset. An experimental reduce loop is used for simulation.

    Returns:
      A `tff.templates.IterativeProcess`.
    """

    # Adapter handed to the generic builder: it creates the client-side
    # FedSGD computation from a freshly constructed model.
    def client_sgd_avg(model_fn: Callable[[], model_lib.Model]) -> ClientSgd:
        return ClientSgd(
            model_fn(),
            client_weight_fn,
            use_experimental_simulation_loop=use_experimental_simulation_loop)

    return optimizer_utils.build_model_delta_optimizer_process(
        model_fn,
        model_to_client_delta_fn=client_sgd_avg,
        server_optimizer_fn=server_optimizer_fn,
        broadcast_process=broadcast_process,
        aggregation_process=aggregation_process,
        model_update_aggregation_factory=model_update_aggregation_factory)
def test_execute_measured_process_with_custom_metrics_aggregator(
        self, server_optimizer):
    """Runs one round with a user-supplied metrics aggregator.

    The custom aggregator reports the per-client maximum and minimum of the
    finalized `loss` and `num_examples` metrics; the test checks the server
    state and the reported outputs around a single call to `next`.
    """
    model_weights_type = model_utils.weights_type_from_model(
        model_examples.LinearRegression)
    learning_rate = 1.0
    server_optimizer_fn = server_optimizer(learning_rate)

    def custom_metrics_aggregator(metric_finalizers,
                                  local_unfinalized_metrics_type):
        """Builds a TFF computation that computes per-client min/max metrics."""

        @tensorflow_computation.tf_computation(
            local_unfinalized_metrics_type)
        def finalizer_computation(unfinalized_metrics):
            # Apply each finalizer to the metric of the same name.
            return collections.OrderedDict(
                (name, finalize(unfinalized_metrics[name]))
                for name, finalize in metric_finalizers.items())

        @federated_computation.federated_computation(
            computation_types.at_clients(local_unfinalized_metrics_type))
        def aggregator_computation(client_local_unfinalized_metrics):
            finalized = intrinsics.federated_map(
                finalizer_computation, client_local_unfinalized_metrics)
            extremes = collections.OrderedDict([
                ('loss_per_client_max',
                 primitives.federated_max(finalized['loss'])),
                ('loss_per_client_min',
                 primitives.federated_min(finalized['loss'])),
                ('num_examples_per_client_max',
                 primitives.federated_max(finalized['num_examples'])),
                ('num_examples_per_client_min',
                 primitives.federated_min(finalized['num_examples'])),
            ])
            return intrinsics.federated_zip(extremes)

        return aggregator_computation

    measured_broadcast = _build_test_measured_broadcast(model_weights_type)
    process = optimizer_utils.build_model_delta_optimizer_process(
        model_fn=model_examples.LinearRegression,
        model_to_client_delta_fn=DummyClientDeltaFn,
        server_optimizer_fn=server_optimizer_fn,
        broadcast_process=measured_broadcast,
        model_update_aggregation_factory=TestMeasuredMeanFactory(),
        metrics_aggregator=custom_metrics_aggregator)

    # Client 1 holds a single example; client 2 holds two.
    client_1_examples = collections.OrderedDict([
        ('x', [[-1.0, -1.0]]),
        ('y', [[1.0]]),
    ])
    client_1_dataset = tf.data.Dataset.from_tensor_slices(
        client_1_examples).batch(1)
    client_2_examples = collections.OrderedDict([
        ('x', [[1.0, 2.0], [3.0, 4.0]]),
        ('y', [[5.0], [6.0]]),
    ])
    client_2_dataset = tf.data.Dataset.from_tensor_slices(
        client_2_examples).batch(2)
    federated_dataset = [client_1_dataset, client_2_dataset]

    state = process.initialize()
    if callable(server_optimizer_fn):
        # Keras SGD tracks one scalar: the number of iterations so far.
        self.assertAllEqual(state.optimizer_state, [0])
    else:
        # TFF SGD keeps the learning rate in its state.
        expected_optimizer_state = collections.OrderedDict([
            (optimizer.LEARNING_RATE_KEY, learning_rate)
        ])
        self.assertAllClose(state.optimizer_state, expected_optimizer_state)
    self.assertAllClose(
        list(state.model.trainable), [np.zeros((2, 1)), 0.0])
    self.assertAllClose(list(state.model.non_trainable), [0.0])
    self.assertEqual(state.delta_aggregate_state, 0)
    self.assertEqual(state.model_broadcast_state, 0)

    state, outputs = process.next(state, federated_dataset)
    # `DummyClientDeltaFn` always sends back fake weight deltas of negative
    # ones. Starting from all-zero weights, the updated weights are therefore
    # all negative ones.
    self.assertAllClose(
        list(state.model.trainable), [-np.ones((2, 1)), -1.0])
    self.assertAllClose(list(state.model.non_trainable), [0.0])
    if callable(server_optimizer_fn):
        self.assertAllEqual(state.optimizer_state, [1])
    self.assertEqual(state.delta_aggregate_state, 1)
    self.assertEqual(state.model_broadcast_state, 1)

    # The mean squared loss is evaluated at the initial (all-zero) weights:
    # 0.5 for the first client and 0.5*(25+36)/2 = 15.25 for the second.
    expected_train_metrics = collections.OrderedDict(
        loss_per_client_max=15.25,
        loss_per_client_min=0.5,
        num_examples_per_client_max=2,
        num_examples_per_client_min=1)
    expected_outputs = collections.OrderedDict(
        # The process from `_build_test_measured_broadcast` reports
        # `measurements` of 3.0 + norm of the initial weights (zero here).
        broadcast=3.0,
        aggregation=collections.OrderedDict(num_clients=2),
        train=expected_train_metrics)
    self.assertAllClose(expected_outputs, outputs)