def test_fed_avg_with_adaptive_client_and_server(self, optimizer):
    """Checks that loss drops when client and server use the same optimizer.

    Args:
        optimizer: Optimizer constructor accepting an `epsilon` keyword. It is
            wrapped with `epsilon=0.01` once for the client and once for the
            server.
    """
    num_rounds = 5
    # A single client holding a single batch.
    client_datasets = [[_batch_fn()]]
    process = fed_avg_client_opt.build_iterative_process(
        _uncompiled_model_builder,
        client_optimizer_fn=functools.partial(optimizer, epsilon=0.01),
        server_optimizer_fn=functools.partial(optimizer, epsilon=0.01))

    run_result = self._run_rounds(process, client_datasets, num_rounds)
    train_outputs = run_result[1]

    # Presumably one entry per round -- compare the fifth entry to the first.
    first_loss = train_outputs[0]['loss']
    fifth_loss = train_outputs[4]['loss']
    self.assertLess(fifth_loss, first_loss)
def test_fed_avg_decreases_loss(self):
    """Checks that loss drops over five rounds with SGD on client and server."""
    num_rounds = 5
    # A single client holding a single batch.
    client_datasets = [[_batch_fn()]]
    process = fed_avg_client_opt.build_iterative_process(
        _uncompiled_model_builder,
        client_optimizer_fn=tf.keras.optimizers.SGD,
        server_optimizer_fn=tf.keras.optimizers.SGD)

    run_result = self._run_rounds(process, client_datasets, num_rounds)
    train_outputs = run_result[1]

    # Presumably one entry per round -- compare the fifth entry to the first.
    first_loss = train_outputs[0]['loss']
    fifth_loss = train_outputs[4]['loss']
    self.assertLess(fifth_loss, first_loss)
def test_state_types(self):
    """Checks the types of the server state and its model after one round.

    The process is built with an Adam client optimizer (`epsilon=0.01`) and an
    SGD server optimizer.
    """
    # A single client holding a single batch.
    client_datasets = [[_batch_fn()]]
    adam_with_epsilon = functools.partial(
        tf.keras.optimizers.Adam, epsilon=0.01)
    process = fed_avg_client_opt.build_iterative_process(
        _uncompiled_model_builder,
        client_optimizer_fn=adam_with_epsilon,
        server_optimizer_fn=tf.keras.optimizers.SGD)

    server_state = self._run_rounds(process, client_datasets, 1)[0]

    self.assertIsInstance(server_state, fed_avg_client_opt.ServerState)
    self.assertIsInstance(server_state.model, tff.learning.ModelWeights)
def test_client_state_aggregate_mean(self, n, m):
    """Checks the aggregated client optimizer iteration count after one round.

    Two clients are given `n` and `m` batches respectively and every client
    optimizer state is weighted by the constant 1.0. The resulting
    `client_optimizer_state.iterations` must equal `(n + m) // 2`.

    Args:
        n: Number of batches in the first client's dataset.
        m: Number of batches in the second client's dataset.
    """
    expected_iterations = (n + m) // 2
    client_datasets = [
        [_batch_fn() for _ in range(n)],
        [_batch_fn() for _ in range(m)],
    ]
    process = fed_avg_client_opt.build_iterative_process(
        _uncompiled_model_builder,
        client_optimizer_fn=functools.partial(
            tf.keras.optimizers.Adam, epsilon=0.01),
        server_optimizer_fn=tf.keras.optimizers.SGD,
        # Constant weight, regardless of the argument passed in.
        client_opt_weight_fn=lambda x: 1.0)

    initial_state = process.initialize()
    next_state = process.next(initial_state, client_datasets)[0]

    self.assertEqual(
        next_state.client_optimizer_state.iterations, expected_iterations)
def test_get_model_weights(self):
    """Checks `get_model_weights` against the state's own model weights.

    The check runs on the freshly initialized state and again after each of
    three training rounds.
    """
    # A single client holding a single batch.
    client_datasets = [[_batch_fn()]]
    process = fed_avg_client_opt.build_iterative_process(
        _uncompiled_model_builder,
        client_optimizer_fn=tf.keras.optimizers.SGD,
        server_optimizer_fn=tf.keras.optimizers.SGD)

    def _check_weights(server_state):
        # The extracted weights must be ModelWeights whose trainable part
        # matches the trainable part stored on the state.
        self.assertIsInstance(
            process.get_model_weights(server_state),
            tff.learning.ModelWeights)
        self.assertAllClose(
            server_state.model.trainable,
            process.get_model_weights(server_state).trainable)

    current_state = process.initialize()
    _check_weights(current_state)

    remaining_rounds = 3
    while remaining_rounds > 0:
        current_state, _ = process.next(current_state, client_datasets)
        _check_weights(current_state)
        remaining_rounds -= 1