def test_client_tf(self, simulation):
  """ClientSgd yields the expected deltas, weight and metrics for one client.

  Args:
    simulation: Forwarded to ClientSgd as `use_experimental_simulation_loop`.
  """
  tff_model = self.model()
  client_data = self.dataset()
  client_update = federated_sgd.ClientSgd(
      tff_model, use_experimental_simulation_loop=simulation)
  result = self.evaluate(client_update(client_data, self.initial_weights()))

  # Both trainable parameters get deltas; the non-trainable 'c' is not
  # returned. Expected values are the model deltas for squared error.
  expected_delta = [[[1.0], [0.0]], 1.0]
  self.assertAllClose(result.weights_delta, expected_delta)
  self.assertAllClose(result.weights_delta_weight, 8.0)

  expected_metrics = dict(
      num_examples=8, num_examples_float=8.0, num_batches=3, loss=0.5)
  self.assertEqual(result.model_output, expected_metrics)

  expected_optimizer_output = dict(client_weight=8.0, has_non_finite_delta=0)
  self.assertEqual(result.optimizer_output, expected_optimizer_output)
def test_client_tf(self):
  """Runs ClientSgd in a TF1-style graph and session and checks its outputs."""
  tff_model = self.model()
  client_data = self.dataset()
  client_update = federated_sgd.ClientSgd(tff_model)
  # One op that initializes both the model's and ClientSgd's own variables.
  initializer = tf.group(
      model_utils.model_initializer(tff_model),
      tf.compat.v1.initializers.variables(client_update.variables),
      name='fedsgd_initializer')
  outputs_op = client_update(client_data, self.initial_weights())
  # Make the graph read-only so nothing after this point can add ops.
  tf.compat.v1.get_default_graph().finalize()

  with self.session() as sess:
    sess.run(initializer)
    result = sess.run(outputs_op)

  deltas = result.weights_delta
  # Trainable 'a' and 'b' get gradients; the non-trainable 'c' is not returned.
  self.assertCountEqual(['a', 'b'], list(deltas.keys()))
  # Model deltas for squared error.
  self.assertAllClose(deltas['a'], [[1.0], [0.0]])
  self.assertAllClose(deltas['b'], 1.0)
  self.assertAllClose(result.weights_delta_weight, 8.0)

  metrics = result.model_output
  self.assertEqual(metrics['num_examples'], 8)
  self.assertEqual(metrics['num_batches'], 3)
  self.assertAlmostEqual(metrics['loss'], 0.5)

  optimizer_output = result.optimizer_output
  self.assertEqual(optimizer_output['client_weight'], 8.0)
  self.assertEqual(optimizer_output['has_non_finite_delta'], 0)
def test_client_tf(self, simulation, weighted):
  """Checks ClientSgd outputs under NUM_EXAMPLES vs. UNIFORM client weighting.

  Args:
    simulation: Forwarded to ClientSgd as `use_experimental_simulation_loop`.
    weighted: If True, use `ClientWeighting.NUM_EXAMPLES`; otherwise use
      `ClientWeighting.UNIFORM`.
  """
  model = self.model()
  dataset = self.dataset()
  if weighted:
    client_weighting = client_weight_lib.ClientWeighting.NUM_EXAMPLES
  else:
    client_weighting = client_weight_lib.ClientWeighting.UNIFORM
  client_tf = federated_sgd.ClientSgd(
      model,
      client_weighting=client_weighting,
      use_experimental_simulation_loop=simulation)
  client_outputs = self.evaluate(client_tf(dataset, self.initial_weights()))

  # Both trainable parameters should have gradients, and we don't return the
  # non-trainable 'c'. Model deltas for squared error:
  self.assertAllClose(client_outputs.weights_delta, [[[1.0], [0.0]], 1.0])
  # NUM_EXAMPLES weighting gives the example count (8, matching
  # 'num_examples' below); UNIFORM weighting gives 1.
  expected_weight = 8.0 if weighted else 1.0
  self.assertAllClose(client_outputs.weights_delta_weight, expected_weight)

  # `model_output` must contain at least these metrics. The check is written
  # out explicitly because `assertDictContainsSubset` is deprecated (removed in
  # Python 3.12). Its (subset, dictionary) argument order is also easy to
  # invert, which silently turns the check into "expected ⊇ actual".
  expected_model_output = {
      'num_examples': 8,
      'num_examples_float': 8.0,
      'num_batches': 3,
      'loss': 0.5,
  }
  model_output = client_outputs.model_output
  for key, expected_value in expected_model_output.items():
    with self.subTest(metric=key):
      self.assertIn(key, model_output)
      self.assertAllClose(model_output[key], expected_value)

  self.assertEqual(client_outputs.optimizer_output['has_non_finite_delta'], 0)
def test_client_tf_custom_batch_weight(self):
  """A custom `batch_weight_fn` determines the reported client weight."""
  tff_model = self.model()
  client_data = self.dataset()

  def doubled_x_sum(batch):
    # Twice the sum of the batch's `x` feature.
    return 2.0 * tf.reduce_sum(batch.x)

  client_update = federated_sgd.ClientSgd(
      tff_model, batch_weight_fn=doubled_x_sum)
  result = client_update(client_data, self.initial_weights())
  # Expected 2 * 8. Presumably the `x` values of self.dataset() sum to 8;
  # TODO confirm against the dataset fixture.
  self.assertEqual(result.weights_delta_weight.numpy(), 16.0)
def test_client_tf(self):
  """ClientSgd yields per-variable deltas, the client weight and metrics."""
  tff_model = self.model()
  client_data = self.dataset()
  client_update = federated_sgd.ClientSgd(tff_model)
  result = self.evaluate(client_update(client_data, self.initial_weights()))

  deltas = result.weights_delta
  # Trainable 'a' and 'b' get gradients; the non-trainable 'c' is not returned.
  self.assertCountEqual(['a', 'b'], deltas.keys())
  # Model deltas for squared error.
  self.assertAllClose(deltas['a'], [[1.0], [0.0]])
  self.assertAllClose(deltas['b'], 1.0)
  self.assertAllClose(result.weights_delta_weight, 8.0)

  expected_metrics = dict(
      num_examples=8, num_examples_float=8.0, num_batches=3, loss=0.5)
  self.assertEqual(result.model_output, expected_metrics)

  expected_optimizer_output = dict(client_weight=8.0, has_non_finite_delta=0)
  self.assertEqual(result.optimizer_output, expected_optimizer_output)
def test_client_tf_dataset_reduce_fn(self, simulation, mock_method):
  """The mocked method runs only when the simulation loop is disabled.

  Args:
    simulation: Forwarded to ClientSgd as `use_experimental_simulation_loop`.
    mock_method: A mock injected by the test's decorators (not visible here);
      presumably it patches the dataset-reduce implementation, so confirm
      against the decorator.
  """
  tff_model = self.model()
  client_data = self.dataset()
  client_update = federated_sgd.ClientSgd(
      tff_model, use_experimental_simulation_loop=simulation)
  client_update(client_data, self.initial_weights())
  if not simulation:
    mock_method.assert_called()
  else:
    mock_method.assert_not_called()
def test_non_finite_aggregation(self, bad_value):
  """A non-finite weight zeroes the client's delta and weight and is flagged.

  Args:
    bad_value: Value written into trainable weight 1 before training.
      Presumably a non-finite value (inf/nan) supplied by the test's
      parameterization; TODO confirm.
  """
  model = self.model()
  dataset = self.dataset()
  client_tf = federated_sgd.ClientSgd(model)
  init_weights = self.initial_weights()
  init_weights.trainable[1] = bad_value
  # Evaluate the whole output structure once instead of piece by piece.
  client_outputs = self.evaluate(client_tf(dataset, init_weights))
  self.assertEqual(client_outputs.weights_delta_weight, 0.0)
  self.assertAllClose(client_outputs.weights_delta, [[[0.0], [0.0]], 0.0])
  # The non-finite delta must also be reported, not just silently zeroed.
  self.assertEqual(client_outputs.optimizer_output['has_non_finite_delta'], 1)
def test_non_finite_aggregation(self, bad_value):
  """Setting trainable 'b' to `bad_value` zeroes all deltas and the weight.

  Args:
    bad_value: Value written into trainable weight 'b'. Presumably non-finite
      (inf/nan) per the test's parameterization; TODO confirm.
  """
  tff_model = self.model()
  client_data = self.dataset()
  client_update = federated_sgd.ClientSgd(tff_model)
  weights = self.initial_weights()
  weights.trainable['b'] = bad_value
  result = client_update(client_data, weights)

  self.assertEqual(result.weights_delta_weight.numpy(), 0.0)
  deltas = result.weights_delta
  self.assertAllClose(deltas['a'].numpy(), np.zeros((2, 1)))
  self.assertAllClose(deltas['b'].numpy(), 0.0)
  # The non-finite delta is flagged in the optimizer output.
  self.assertEqual(
      result.optimizer_output['has_non_finite_delta'].numpy(), 1)
def test_client_tf(self):
  """ClientSgd outputs, converted to numpy, match the expected values."""
  tff_model = self.model()
  client_data = self.dataset()
  client_update = federated_sgd.ClientSgd(tff_model)
  raw_result = client_update(client_data, self.initial_weights())
  # Turn every eager tensor in the nested output into a numpy value.
  result = nest.map_structure(lambda tensor: tensor.numpy(), raw_result)

  deltas = result.weights_delta
  # Trainable 'a' and 'b' get gradients; the non-trainable 'c' is not returned.
  self.assertCountEqual(['a', 'b'], deltas.keys())
  # Model deltas for squared error.
  self.assertAllClose(deltas['a'], [[1.0], [0.0]])
  self.assertAllClose(deltas['b'], 1.0)
  self.assertAllClose(result.weights_delta_weight, 8.0)

  metrics = result.model_output
  self.assertEqual(metrics['num_examples'], 8)
  self.assertEqual(metrics['num_batches'], 3)
  self.assertAlmostEqual(metrics['loss'], 0.5)

  optimizer_output = result.optimizer_output
  self.assertEqual(optimizer_output['client_weight'], 8.0)
  self.assertEqual(optimizer_output['has_non_finite_delta'], 0)
def test_clietsgd_fails_for_non_tff_model(self):
  """ClientSgd rejects a plain Keras model with a TypeError matching 'Model'."""
  # NOTE: 'clietsgd' in the method name is a typo for 'clientsgd'; the name is
  # kept unchanged so existing test IDs and filters keep working.
  plain_keras_model = tf.keras.Sequential(layers=[tf.keras.layers.Dense(1)])
  with self.assertRaisesRegex(TypeError, 'Model'):
    federated_sgd.ClientSgd(plain_keras_model)