def test_weighted_fed_avg_raises_on_unweighted_aggregator(self):
  """Weighted FedAvg must reject an unweighted model aggregator.

  The robust aggregator is built with `weighted=False` and handed to
  `build_weighted_fed_avg`. The builder is expected to raise a `TypeError`
  whose message mentions `WeightedAggregationFactory`.
  """
  unweighted_factory = model_update_aggregator.robust_aggregator(
      weighted=False)
  with self.assertRaisesRegex(TypeError, 'WeightedAggregationFactory'):
    build_kwargs = {
        'model_fn': model_examples.LinearRegression,
        'client_optimizer_fn': sgdm.build_sgdm(1.0),
        'model_aggregator': unweighted_factory,
    }
    fed_avg.build_weighted_fed_avg(**build_kwargs)
def test_client_tf_dataset_reduce_fn(self, simulation, mock_method):
  """Checks whether building weighted FedAvg invokes the patched method.

  `simulation` is forwarded as `use_experimental_simulation_loop`.
  - When it is truthy, `mock_method` must not have been called.
  - Otherwise it must have been called at least once.

  NOTE(review): both arguments come from decorators outside this view.
  Presumably these are a parameterization and a `mock.patch` of the client
  dataset-reduce function; confirm.
  """
  fed_avg.build_weighted_fed_avg(
      model_fn=model_examples.LinearRegression,
      client_optimizer_fn=sgdm.build_sgdm(1.0),
      use_experimental_simulation_loop=simulation)
  if not simulation:
    mock_method.assert_called()
    return
  mock_method.assert_not_called()
def test_construction_calls_model_fn(self, optimizer_fn, aggregation_factory):
  """Building weighted FedAvg invokes `model_fn` exactly three times.

  `model_fn` can be expensive (loading weights, preprocessing, etc.), so
  this pins how often process construction calls it. A `mock.Mock` wraps
  the real `LinearRegression` constructor so models are still produced
  while calls are counted.
  """
  counting_model_fn = mock.Mock(side_effect=model_examples.LinearRegression)
  fed_avg.build_weighted_fed_avg(
      model_fn=counting_model_fn,
      client_optimizer_fn=optimizer_fn,
      model_aggregator=aggregation_factory())
  expected_model_fn_calls = 3
  self.assertEqual(counting_model_fn.call_count, expected_model_fn_calls)
def test_raises_on_invalid_distributor(self):
  """A distributor re-wrapped as a plain `IterativeProcess` is rejected.

  A broadcast distributor is built for the LinearRegression weights type.
  Its `initialize`/`next` computations are then re-wrapped in an
  `IterativeProcess`. Passing that as `model_distributor` must raise
  `TypeError`.

  NOTE(review): presumably the builder requires the original
  distribution-process type; confirm.
  """
  linear_model = model_examples.LinearRegression()
  weights_type = type_conversions.type_from_tensors(
      model_utils.ModelWeights.from_model(linear_model))
  broadcast_process = distributors.build_broadcast_process(weights_type)
  plain_iterative_process = iterative_process.IterativeProcess(
      broadcast_process.initialize, broadcast_process.next)
  with self.assertRaises(TypeError):
    fed_avg.build_weighted_fed_avg(
        model_fn=model_examples.LinearRegression,
        client_optimizer_fn=sgdm.build_sgdm(1.0),
        model_distributor=plain_iterative_process)
def test_equivalent_to_vanilla_fed_avg(self):
  """Mime Lite with momentum-free SGD should reduce to plain FedAvg.

  Both processes use `sgdm.build_sgdm(0.1)` and run for three rounds on the
  same single-client dataset. After each round, the following must be
  close:
  - the flattened model weights;
  - the client training `loss`;
  - the client training `num_examples`.
  """
  mime_process = mime.build_weighted_mime_lite(
      model_fn=_create_model, base_optimizer=sgdm.build_sgdm(0.1))
  fed_avg_process = fed_avg.build_weighted_fed_avg(
      model_fn=_create_model, client_optimizer_fn=sgdm.build_sgdm(0.1))
  federated_data = [_create_dataset()]

  mime_state = mime_process.initialize()
  fed_avg_state = fed_avg_process.initialize()

  for _ in range(3):
    # Advance Mime first, then FedAvg, exactly as before.
    mime_result = mime_process.next(mime_state, federated_data)
    fed_avg_result = fed_avg_process.next(fed_avg_state, federated_data)
    mime_state, fed_avg_state = mime_result.state, fed_avg_result.state

    self.assertAllClose(
        tf.nest.flatten(mime_process.get_model_weights(mime_state)),
        tf.nest.flatten(fed_avg_process.get_model_weights(fed_avg_state)))

    mime_train = mime_result.metrics['client_work']['train']
    fed_avg_train = fed_avg_result.metrics['client_work']['train']
    for metric_name in ('loss', 'num_examples'):
      self.assertAllClose(mime_train[metric_name], fed_avg_train[metric_name])
def test_weighted_fed_avg_with_only_secure_aggregation(self):
  """Weighted FedAvg built only from secure aggregators has no unsecure sums.

  Both aggregators are secure variants:
  - the model updates use the weighted secure aggregator;
  - the metrics use `secure_sum_then_finalize`.
  Therefore the process's `next` computation must pass
  `assert_not_contains_unsecure_aggregation`.
  """
  def keras_sgd_optimizer():
    return tf.keras.optimizers.SGD(1.0)

  process = fed_avg.build_weighted_fed_avg(
      model_examples.LinearRegression,
      client_optimizer_fn=keras_sgd_optimizer,
      model_aggregator=model_update_aggregator.secure_aggregator(
          weighted=True),
      metrics_aggregator=aggregator.secure_sum_then_finalize)
  static_assert.assert_not_contains_unsecure_aggregation(process.next)
def test_weighted_fed_avg_with_only_secure_aggregation(self):
  """Functional-model weighted FedAvg with secure aggregators is secure-only.

  The process is built from a functional linear-regression model: `model` is
  given and `model_fn=None`. Both the model and metrics aggregators are
  secure variants, so `next` must contain no unsecure aggregation.
  """
  functional_model = test_models.build_functional_linear_regression()
  optimizer = sgdm.build_sgdm(learning_rate=0.1)
  secure_model_aggregator = model_update_aggregator.secure_aggregator(
      weighted=True)
  process = fed_avg.build_weighted_fed_avg(
      model_fn=None,
      model=functional_model,
      client_optimizer_fn=optimizer,
      model_aggregator=secure_model_aggregator,
      metrics_aggregator=aggregator.secure_sum_then_finalize)
  static_assert.assert_not_contains_unsecure_aggregation(process.next)
def test_raises_on_non_callable_model_fn(self):
  """Passing a model instance instead of a constructor raises `TypeError`."""
  with self.assertRaises(TypeError):
    # An already-built model, not a zero-arg callable returning one.
    model_instance = model_examples.LinearRegression()
    fed_avg.build_weighted_fed_avg(
        model_fn=model_instance,
        client_optimizer_fn=tf.keras.optimizers.SGD)
def test_raises_on_model_and_model_fn(self):
  """Supplying both `model_fn` and a functional `model` raises `ValueError`."""
  with self.assertRaises(ValueError):
    conflicting_kwargs = {
        'model_fn': model_examples.LinearRegression,
        'model': test_models.build_functional_linear_regression(),
        'client_optimizer_fn': sgdm.build_sgdm(learning_rate=0.1),
    }
    fed_avg.build_weighted_fed_avg(**conflicting_kwargs)
def test_raises_on_invalid_client_weighting(self):
  """A bare string for `client_weighting` raises `TypeError`.

  NOTE(review): presumably only a client-weighting enum member or a callable
  is accepted; confirm against `build_weighted_fed_avg`.
  """
  invalid_weighting = 'uniform'
  with self.assertRaises(TypeError):
    fed_avg.build_weighted_fed_avg(
        model_fn=model_examples.LinearRegression,
        client_optimizer_fn=sgdm.build_sgdm(1.0),
        client_weighting=invalid_weighting)