def test_client_tf(self):
        """Runs client-side FedAvg training once and checks its outputs.

        Trains the model from `self.create_model()` on the dataset from
        `self.create_dataset()` with SGD (learning rate 0.1), starting from
        `self.initial_weights()`. It then checks the weight delta, the delta's
        aggregation weight, the optimizer diagnostics and the model metrics.
        """
        model = self.create_model()
        dataset = self.create_dataset()
        client_tf = federated_averaging._ClientFedAvg(
            model, tf.keras.optimizers.SGD(learning_rate=0.1))
        client_outputs = self.evaluate(
            client_tf(dataset, self.initial_weights()))

        # Both trainable parameters should have been updated,
        # and we don't return the non-trainable variable.
        self.assertAllGreater(
            np.linalg.norm(client_outputs.weights_delta, axis=-1), 0.1)
        # The norm above is taken over the whole (nested) delta structure, so
        # it does not by itself show that *each* parameter moved; check every
        # trainable delta individually as well.
        for delta in tf.nest.flatten(client_outputs.weights_delta):
            self.assertGreater(np.linalg.norm(delta), 0.0)
        self.assertEqual(client_outputs.weights_delta_weight, 8.0)
        self.assertEqual(client_outputs.optimizer_output['num_examples'], 8)
        self.assertEqual(
            client_outputs.optimizer_output['has_non_finite_delta'], 0)

        # assertDictContainsSubset is deprecated (removed in Python 3.12), so
        # compare the expected metric entries key by key instead.
        expected_model_output = {
            'num_examples': 8,
            'num_examples_float': 8.0,
            'num_batches': 3,
        }
        for key, expected_value in expected_model_output.items():
            self.assertIn(key, client_outputs.model_output)
            self.assertEqual(
                client_outputs.model_output[key], expected_value, msg=key)
        self.assertBetween(client_outputs.model_output['loss'],
                           np.finfo(np.float32).eps, 10.0)
Example no. 2
0
 def test_client_tf_custom_delta_weight(self):
   """A user-supplied `client_weight_fn` sets the delta weight.

   The weight function always returns the constant 1.5, so the client's
   reported `weights_delta_weight` must be exactly 1.5.
   """
   model = self.create_model()
   dataset = self.create_dataset()
   sgd = tf.keras.optimizers.SGD(learning_rate=0.1)

   def constant_weight_fn(_):
     return tf.constant(1.5)

   client_tf = federated_averaging._ClientFedAvg(
       model, sgd, client_weight_fn=constant_weight_fn)
   delta_weight = client_tf(
       dataset, self.initial_weights()).weights_delta_weight
   self.assertEqual(self.evaluate(delta_weight), 1.5)
Example no. 3
0
 def test_non_finite_aggregation(self, bad_value):
   """A bad initial weight makes the client contribute nothing.

   `bad_value` is written into the second trainable weight before training.
   Judging by the test name it is presumably a non-finite value such as NaN
   or inf, supplied by a parameterization decorator that is not shown here.
   The client must then report a zero delta and a zero delta weight, and it
   must set `has_non_finite_delta` to 1.
   """
   model = self.create_model()
   dataset = self.create_dataset()
   sgd = tf.keras.optimizers.SGD(learning_rate=0.1)
   client_tf = federated_averaging._ClientFedAvg(model, sgd)

   poisoned_weights = self.initial_weights()
   poisoned_weights.trainable[1] = bad_value
   outputs = client_tf(dataset, poisoned_weights)

   all_zero_delta = [[[0.0], [0.0]], 0.0]
   self.assertEqual(self.evaluate(outputs.weights_delta_weight), 0.0)
   self.assertAllClose(self.evaluate(outputs.weights_delta), all_zero_delta)
   non_finite_flag = outputs.optimizer_output['has_non_finite_delta']
   self.assertEqual(self.evaluate(non_finite_flag), 1)