示例#1
0
 def test_weighted_fed_avg_raises_on_unweighted_aggregator(self):
     """An unweighted model aggregator is rejected by the weighted builder.

     `robust_aggregator(weighted=False)` is handed to
     `build_weighted_fed_avg`; the test requires a `TypeError` whose message
     matches 'WeightedAggregationFactory'.
     """
     unweighted_factory = model_update_aggregator.robust_aggregator(
         weighted=False)
     expected_message = 'WeightedAggregationFactory'
     with self.assertRaisesRegex(TypeError, expected_message):
         fed_avg.build_weighted_fed_avg(
             model_fn=model_examples.LinearRegression,
             client_optimizer_fn=sgdm.build_sgdm(1.0),
             model_aggregator=unweighted_factory)
示例#2
0
 def test_client_tf_dataset_reduce_fn(self, simulation, mock_method):
     """Checks whether building FedAvg invokes the mocked method.

     When `use_experimental_simulation_loop` is true the mocked method must
     never be called; otherwise it must have been called at least once.
     """
     # NOTE(review): `simulation` and `mock_method` are presumably injected by
     # parameterization / mock.patch decorators outside this view -- confirm.
     fed_avg.build_weighted_fed_avg(
         model_fn=model_examples.LinearRegression,
         client_optimizer_fn=sgdm.build_sgdm(1.0),
         use_experimental_simulation_loop=simulation)
     verify_calls = (mock_method.assert_not_called
                     if simulation else mock_method.assert_called)
     verify_calls()
示例#3
0
 def test_construction_calls_model_fn(self, optimizer_fn, aggregation_factory):
     """Building the process invokes `model_fn` exactly three times.

     `model_fn` can be expensive (loading weights, preprocessing, etc.), so
     this pins how many times process construction calls it.
     """
     # The mock delegates to the real model constructor while counting calls.
     counting_model_fn = mock.Mock(
         side_effect=model_examples.LinearRegression)
     fed_avg.build_weighted_fed_avg(
         model_fn=counting_model_fn,
         client_optimizer_fn=optimizer_fn,
         model_aggregator=aggregation_factory())
     expected_calls = 3
     self.assertEqual(counting_model_fn.call_count, expected_calls)
示例#4
0
 def test_raises_on_invalid_distributor(self):
     """A re-wrapped `IterativeProcess` is rejected as `model_distributor`.

     The `initialize`/`next` computations of a broadcast process are wrapped
     in a plain `iterative_process.IterativeProcess`. Passing that object as
     `model_distributor` must raise `TypeError`.
     """
     linear_model = model_examples.LinearRegression()
     initial_weights = model_utils.ModelWeights.from_model(linear_model)
     weights_type = type_conversions.type_from_tensors(initial_weights)
     broadcast_process = distributors.build_broadcast_process(weights_type)
     # Same computations, but (presumably) not the process type the builder's
     # type check accepts -- hence the expected TypeError.
     rewrapped_process = iterative_process.IterativeProcess(
         broadcast_process.initialize, broadcast_process.next)
     with self.assertRaises(TypeError):
         fed_avg.build_weighted_fed_avg(
             model_fn=model_examples.LinearRegression,
             client_optimizer_fn=sgdm.build_sgdm(1.0),
             model_distributor=rewrapped_process)
示例#5
0
    def test_equivalent_to_vanilla_fed_avg(self):
        """Mime Lite with no-momentum SGD matches FedAvg round by round.

        Both processes are initialized and then stepped three times on the
        same single-client data. After every round the flattened model weights
        and the `loss` and `num_examples` training metrics must agree.
        """
        mime_process = mime.build_weighted_mime_lite(
            model_fn=_create_model, base_optimizer=sgdm.build_sgdm(0.1))
        fed_avg_process = fed_avg.build_weighted_fed_avg(
            model_fn=_create_model, client_optimizer_fn=sgdm.build_sgdm(0.1))

        client_datasets = [_create_dataset()]
        mime_state = mime_process.initialize()
        fed_avg_state = fed_avg_process.initialize()

        for _round in range(3):
            mime_result = mime_process.next(mime_state, client_datasets)
            mime_state, mime_metrics = mime_result.state, mime_result.metrics
            fed_avg_result = fed_avg_process.next(
                fed_avg_state, client_datasets)
            fed_avg_state, fed_avg_metrics = (
                fed_avg_result.state, fed_avg_result.metrics)

            mime_weights = tf.nest.flatten(
                mime_process.get_model_weights(mime_state))
            fed_avg_weights = tf.nest.flatten(
                fed_avg_process.get_model_weights(fed_avg_state))
            self.assertAllClose(mime_weights, fed_avg_weights)

            # Compare the training metrics both algorithms report.
            for metric_name in ('loss', 'num_examples'):
                self.assertAllClose(
                    mime_metrics['client_work']['train'][metric_name],
                    fed_avg_metrics['client_work']['train'][metric_name])
示例#6
0
 def test_weighted_fed_avg_with_only_secure_aggregation(self):
     """Weighted FedAvg with secure aggregators contains no unsecure sums.

     Model updates go through `secure_aggregator(weighted=True)` and metrics
     through `secure_sum_then_finalize`. `static_assert` must then find no
     unsecure aggregation in the process's `next` computation.
     """
     make_client_optimizer = lambda: tf.keras.optimizers.SGD(1.0)
     secure_model_aggregator = model_update_aggregator.secure_aggregator(
         weighted=True)
     process = fed_avg.build_weighted_fed_avg(
         model_examples.LinearRegression,
         client_optimizer_fn=make_client_optimizer,
         model_aggregator=secure_model_aggregator,
         metrics_aggregator=aggregator.secure_sum_then_finalize)
     static_assert.assert_not_contains_unsecure_aggregation(process.next)
示例#7
0
 def test_weighted_fed_avg_with_only_secure_aggregation(self):
     """A functional-model FedAvg with secure aggregators has no unsecure sums.

     The process is built from a functional `model` (with `model_fn=None`).
     Model updates use `secure_aggregator(weighted=True)` and metrics use
     `secure_sum_then_finalize`. `static_assert` must find no unsecure
     aggregation in `next`.
     """
     functional_model = test_models.build_functional_linear_regression()
     client_optimizer = sgdm.build_sgdm(learning_rate=0.1)
     secure_model_aggregator = model_update_aggregator.secure_aggregator(
         weighted=True)
     process = fed_avg.build_weighted_fed_avg(
         model_fn=None,
         model=functional_model,
         client_optimizer_fn=client_optimizer,
         model_aggregator=secure_model_aggregator,
         metrics_aggregator=aggregator.secure_sum_then_finalize)
     static_assert.assert_not_contains_unsecure_aggregation(process.next)
示例#8
0
 def test_raises_on_non_callable_model_fn(self):
     """Passing a model instance instead of a callable raises TypeError."""
     with self.assertRaises(TypeError):
         # Note the call parentheses: this is an already-built model object,
         # not a callable that produces one.
         model_instance = model_examples.LinearRegression()
         fed_avg.build_weighted_fed_avg(
             model_fn=model_instance,
             client_optimizer_fn=tf.keras.optimizers.SGD)
示例#9
0
 def test_raises_on_model_and_model_fn(self):
     """Supplying both `model_fn` and `model` is rejected with ValueError."""
     with self.assertRaises(ValueError):
         functional_model = test_models.build_functional_linear_regression()
         client_optimizer = sgdm.build_sgdm(learning_rate=0.1)
         fed_avg.build_weighted_fed_avg(
             model_fn=model_examples.LinearRegression,
             model=functional_model,
             client_optimizer_fn=client_optimizer)
示例#10
0
 def test_raises_on_invalid_client_weighting(self):
     """A bare string such as 'uniform' for `client_weighting` raises TypeError."""
     with self.assertRaises(TypeError):
         client_optimizer = sgdm.build_sgdm(1.0)
         fed_avg.build_weighted_fed_avg(
             model_fn=model_examples.LinearRegression,
             client_optimizer_fn=client_optimizer,
             client_weighting='uniform')