def test_client_tf_custom_delta_weight(self):
  """A custom `client_weight_fn` should override the delta weight."""
  keras_model = self.create_model()
  data = self.create_dataset()
  # Weight every client by a constant 1.5 regardless of its data.
  client_update = federated_averaging._DeprecatedClientFedAvg(
      keras_model, client_weight_fn=lambda _: tf.constant(1.5))
  outputs = client_update(data, self.initial_weights())
  self.assertEqual(self.evaluate(outputs.weights_delta_weight), 1.5)
def test_non_finite_aggregation(self, bad_value):
  """A non-finite trainable weight must zero out the client's contribution."""
  keras_model = self.create_model()
  data = self.create_dataset()
  client_update = federated_averaging._DeprecatedClientFedAvg(keras_model)
  # Poison one trainable variable so the computed delta is non-finite.
  weights = self.initial_weights()
  weights.trainable[1] = bad_value
  outputs = client_update(data, weights)
  # The client's weight drops to zero and its delta is replaced by zeros.
  self.assertEqual(self.evaluate(outputs.weights_delta_weight), 0.0)
  self.assertAllClose(
      self.evaluate(outputs.weights_delta), [[[0.0], [0.0]], 0.0])
  # The non-finite condition is surfaced in the optimizer output.
  self.assertEqual(
      self.evaluate(outputs.optimizer_output['has_non_finite_delta']), 1)
def test_client_tf(self):
  """End-to-end check of the default client update on a small dataset."""
  model = self.create_model()
  dataset = self.create_dataset()
  client_tf = federated_averaging._DeprecatedClientFedAvg(model)
  client_outputs = self.evaluate(
      client_tf(dataset, self.initial_weights()))
  # Both trainable parameters should have been updated,
  # and we don't return the non-trainable variable.
  self.assertAllGreater(
      np.linalg.norm(client_outputs.weights_delta, axis=-1), 0.1)
  self.assertEqual(client_outputs.weights_delta_weight, 8.0)
  self.assertEqual(client_outputs.optimizer_output['num_examples'], 8)
  self.assertEqual(
      client_outputs.optimizer_output['has_non_finite_delta'], 0)
  # `assertDictContainsSubset` was deprecated in Python 3.2 and removed in
  # 3.12; assert each expected key/value pair explicitly instead.
  expected_counters = {
      'num_examples': 8,
      'num_examples_float': 8.0,
      'num_batches': 3,
  }
  for key, expected_value in expected_counters.items():
    self.assertIn(key, client_outputs.model_output)
    self.assertEqual(client_outputs.model_output[key], expected_value)
  self.assertBetween(client_outputs.model_output['loss'],
                     np.finfo(np.float32).eps, 10.0)