Beispiel #1
0
 def test_raises_on_invalid_client_weighting(self):
     """Passing the string 'uniform' as `client_weighting` must raise TypeError.

     The optimizer is built before entering the `assertRaises` context so that
     only `build_weighted_mime_lite` itself can satisfy the expected TypeError;
     an error raised while constructing the optimizer would otherwise make
     this test pass for the wrong reason.
     """
     base_optimizer = sgdm.build_sgdm(learning_rate=0.01, momentum=0.9)
     with self.assertRaises(TypeError):
         mime.build_weighted_mime_lite(
             model_fn=model_examples.LinearRegression,
             base_optimizer=base_optimizer,
             client_weighting='uniform')
Beispiel #2
0
 def test_client_tf_dataset_reduce_fn(self, simulation, mock_method):
     """Checks whether the patched reduce function is used when building.

     `simulation` toggles `use_experimental_simulation_loop`. When it is
     enabled, `mock_method` must never be called. Otherwise it must be
     called at least once. (`mock_method` presumably patches the client-side
     dataset reduce function; the patch decorator is outside this view, so
     confirm against it.)
     """
     optimizer = sgdm.build_sgdm(learning_rate=0.01, momentum=0.9)
     mime.build_weighted_mime_lite(
         model_fn=model_examples.LinearRegression,
         base_optimizer=optimizer,
         use_experimental_simulation_loop=simulation)
     check_usage = (mock_method.assert_not_called if simulation
                    else mock_method.assert_called)
     check_usage()
Beispiel #3
0
 def test_construction_calls_model_fn(self):
     """Pins how many times process construction invokes `model_fn`.

     `model_fn` can be expensive (loading weights, preprocessing, etc.), so
     building the process should not call it more often than needed. This
     test fixes the count at exactly three.
     """
     counting_model_fn = mock.Mock(side_effect=model_examples.LinearRegression)
     optimizer = sgdm.build_sgdm(learning_rate=0.01, momentum=0.9)
     mime.build_weighted_mime_lite(
         model_fn=counting_model_fn, base_optimizer=optimizer)
     self.assertEqual(counting_model_fn.call_count, 3)
Beispiel #4
0
 def test_raises_on_invalid_distributor(self):
     """A bare `IterativeProcess` is rejected as `model_distributor`.

     A valid broadcast process is built for the model-weights type. Its
     `initialize`/`next` are then re-wrapped in a plain
     `iterative_process.IterativeProcess`, and the builder is expected to
     reject that with a TypeError. (Presumably this is because the wrapper is
     no longer the distribution-process type the builder checks for; confirm
     against the builder.)

     Everything except the builder call is constructed outside the
     `assertRaises` context, so only the builder can satisfy the assertion.
     """
     model_weights_type = type_conversions.type_from_tensors(
         model_utils.ModelWeights.from_model(
             model_examples.LinearRegression()))
     distributor = distributors.build_broadcast_process(model_weights_type)
     invalid_distributor = iterative_process.IterativeProcess(
         distributor.initialize, distributor.next)
     base_optimizer = sgdm.build_sgdm(learning_rate=0.01, momentum=0.9)
     with self.assertRaises(TypeError):
         mime.build_weighted_mime_lite(
             model_fn=model_examples.LinearRegression,
             base_optimizer=base_optimizer,
             model_distributor=invalid_distributor)
Beispiel #5
0
 def test_weighted_mime_lite_raises_on_unweighted_aggregator(self):
     """An unweighted aggregator is rejected for either aggregator argument.

     The same unweighted robust aggregator is passed in turn as
     `model_aggregator` and as `full_gradient_aggregator`. Each attempt must
     raise a TypeError whose message mentions 'WeightedAggregationFactory'.
     Each case runs in its own `subTest`, so a failure for one argument
     still reports the result for the other.
     """
     aggregator = model_update_aggregator.robust_aggregator(weighted=False)
     for aggregator_arg in ('model_aggregator', 'full_gradient_aggregator'):
         # Build the optimizer outside the assertion context so that only the
         # builder call can satisfy the expected TypeError.
         base_optimizer = sgdm.build_sgdm(learning_rate=0.01, momentum=0.9)
         with self.subTest(aggregator_arg=aggregator_arg):
             with self.assertRaisesRegex(TypeError,
                                         'WeightedAggregationFactory'):
                 mime.build_weighted_mime_lite(
                     model_fn=model_examples.LinearRegression,
                     base_optimizer=base_optimizer,
                     **{aggregator_arg: aggregator})
Beispiel #6
0
    def test_equivalent_to_vanilla_fed_avg(self):
        """Mime Lite with momentum-free SGD should match FedAvg round by round.

        Both processes use SGD with learning rate 0.1 and are run for three
        rounds on the same single-client dataset. After every round, the
        flattened model weights and the client training 'loss' and
        'num_examples' metrics must agree.
        """
        mime_process = mime.build_weighted_mime_lite(
            model_fn=_create_model, base_optimizer=sgdm.build_sgdm(0.1))
        fed_avg_process = fed_avg.build_weighted_fed_avg(
            model_fn=_create_model, client_optimizer_fn=sgdm.build_sgdm(0.1))

        client_data = [_create_dataset()]
        mime_state = mime_process.initialize()
        fed_avg_state = fed_avg_process.initialize()

        for _round in range(3):
            mime_result = mime_process.next(mime_state, client_data)
            fed_avg_result = fed_avg_process.next(fed_avg_state, client_data)
            mime_state, fed_avg_state = mime_result.state, fed_avg_result.state

            mime_weights = mime_process.get_model_weights(mime_state)
            fed_avg_weights = fed_avg_process.get_model_weights(fed_avg_state)
            self.assertAllClose(
                tf.nest.flatten(mime_weights), tf.nest.flatten(fed_avg_weights))

            mime_train = mime_result.metrics['client_work']['train']
            fed_avg_train = fed_avg_result.metrics['client_work']['train']
            for metric_name in ('loss', 'num_examples'):
                self.assertAllClose(mime_train[metric_name],
                                    fed_avg_train[metric_name])
Beispiel #7
0
 def test_weighted_mime_lite_with_only_secure_aggregation(self):
     """With secure aggregators everywhere, `next` has no unsecure aggregation.

     A weighted secure aggregator is used for both the model update and the
     full gradient, and metrics use `secure_sum_then_finalize`. The resulting
     `next` computation is then statically checked.
     """
     secure_aggregator = model_update_aggregator.secure_aggregator(weighted=True)
     optimizer = sgdm.build_sgdm(learning_rate=0.01, momentum=0.9)
     learning_process = mime.build_weighted_mime_lite(
         model_examples.LinearRegression,
         base_optimizer=optimizer,
         model_aggregator=secure_aggregator,
         full_gradient_aggregator=secure_aggregator,
         metrics_aggregator=metrics_aggregator.secure_sum_then_finalize)
     next_computation = learning_process.next
     static_assert.assert_not_contains_unsecure_aggregation(next_computation)
Beispiel #8
0
    def test_execution_with_optimizers(self, base_optimizer, server_optimizer):
        """Runs three rounds with the given base/server optimizer pair.

        After each round, the reported client training 'num_examples' metric
        must equal 8. (This presumably matches the size of the dataset from
        `_create_dataset`; confirm against that helper.)
        """
        process = mime.build_weighted_mime_lite(
            model_fn=_create_model,
            base_optimizer=base_optimizer,
            server_optimizer=server_optimizer)

        client_data = [_create_dataset()]
        current_state = process.initialize()

        for _round in range(3):
            round_output = process.next(current_state, client_data)
            current_state = round_output.state
            train_metrics = round_output.metrics['client_work']['train']
            self.assertEqual(8, train_metrics['num_examples'])
Beispiel #9
0
 def test_raises_on_non_callable_model_fn(self):
     """Passing a model instance (not a callable) as `model_fn` raises TypeError.

     The model instance and the optimizer are created before entering the
     `assertRaises` context. A TypeError raised while constructing either of
     them would otherwise be swallowed, and the test would pass without
     exercising the builder's `model_fn` check.
     """
     non_callable_model_fn = model_examples.LinearRegression()
     base_optimizer = sgdm.build_sgdm(learning_rate=0.01, momentum=0.9)
     with self.assertRaises(TypeError):
         mime.build_weighted_mime_lite(
             model_fn=non_callable_model_fn,
             base_optimizer=base_optimizer)