Ejemplo n.º 1
0
 def test_client_tf_custom_delta_weight(self):
     """The value from `client_weight_fn` is reported as the delta weight."""

     def constant_weight(_):
         # Ignores its argument and always weights the client by 1.5.
         return tf.constant(1.5)

     model = self.create_model()
     dataset = self.create_dataset()
     client_tf = federated_averaging._DeprecatedClientFedAvg(
         model, client_weight_fn=constant_weight)
     outputs = client_tf(dataset, self.initial_weights())
     delta_weight = self.evaluate(outputs.weights_delta_weight)
     self.assertEqual(delta_weight, 1.5)
Ejemplo n.º 2
0
 def test_non_finite_aggregation(self, bad_value):
     """A flagged non-finite update yields a zeroed, zero-weighted delta.

     `bad_value` comes from a parameterization not visible here (presumably
     inf/nan values -- confirm against the decorator).
     """
     model = self.create_model()
     dataset = self.create_dataset()
     client_tf = federated_averaging._DeprecatedClientFedAvg(model)
     weights = self.initial_weights()
     # Overwrite the second trainable tensor with the injected value.
     weights.trainable[1] = bad_value
     outputs = client_tf(dataset, weights)

     delta_weight = self.evaluate(outputs.weights_delta_weight)
     self.assertEqual(delta_weight, 0.0)

     delta = self.evaluate(outputs.weights_delta)
     self.assertAllClose(delta, [[[0.0], [0.0]], 0.0])

     non_finite_flag = self.evaluate(
         outputs.optimizer_output['has_non_finite_delta'])
     self.assertEqual(non_finite_flag, 1)
Ejemplo n.º 3
0
    def test_client_tf(self):
        """Checks the delta, delta weight and metrics of one client update."""
        model = self.create_model()
        dataset = self.create_dataset()
        client_tf = federated_averaging._DeprecatedClientFedAvg(model)
        client_outputs = self.evaluate(
            client_tf(dataset, self.initial_weights()))

        # Both trainable parameters should have been updated,
        # and we don't return the non-trainable variable.
        self.assertAllGreater(
            np.linalg.norm(client_outputs.weights_delta, axis=-1), 0.1)
        # The delta weight is expected to match the example count asserted
        # just below (8 examples in the test dataset).
        self.assertEqual(client_outputs.weights_delta_weight, 8.0)
        self.assertEqual(client_outputs.optimizer_output['num_examples'], 8)
        self.assertEqual(
            client_outputs.optimizer_output['has_non_finite_delta'], 0)

        # `assertDictContainsSubset` was deprecated in Python 3.2 and removed
        # in 3.12, so check each expected metric explicitly instead.
        expected_metrics = {
            'num_examples': 8,
            'num_examples_float': 8.0,
            'num_batches': 3,
        }
        model_output = client_outputs.model_output
        for key, expected_value in expected_metrics.items():
            self.assertIn(key, model_output)
            self.assertEqual(model_output[key], expected_value, msg=key)
        self.assertBetween(model_output['loss'],
                           np.finfo(np.float32).eps, 10.0)