def _assert_server_update_with_all_ones(self, model_fn):
  """Checks that two server rounds with an all-ones delta yield weights of 0.2.

  Builds a model from `model_fn`, creates an SGD optimizer with learning rate
  0.1, initializes server state with `_server_init`, and then calls
  `simple_fedavg_tf.server_update` twice. Each call uses a `weights_delta`
  of all ones shaped like the model's trainable variables.

  Args:
    model_fn: No-argument callable returning a model. The assertions assume
      the model has exactly two trainable variables and that they start at
      all zeros (NOTE(review): initial values come from `model_fn`, so
      confirm for each model passed in).
  """
  model = model_fn()
  # Construct the optimizer directly; wrapping it in a named lambda that is
  # called exactly once adds nothing (PEP 8 E731).
  optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
  state = _server_init(model, optimizer)
  weights_delta = tf.nest.map_structure(tf.ones_like, model.trainable_variables)

  # Shared by the update loop and the round counter assertion below.
  num_rounds = 2
  for _ in range(num_rounds):
    state = simple_fedavg_tf.server_update(model, optimizer, state,
                                           weights_delta)

  model_vars = self.evaluate(state.model_weights)
  train_vars = model_vars.trainable
  self.assertLen(train_vars, 2)
  self.assertEqual(state.round_num, num_rounds)
  # Weights are expected to start at all zeros, weights_delta is all ones, and
  # the SGD learning rate is 0.1. After two server updates every trainable
  # value is expected to be 0.2.
  self.assertAllClose(train_vars, [np.ones_like(v) * 0.2 for v in train_vars])
def server_update_fn(server_state, model_delta):
  """Runs `server_update` on a freshly constructed model and server optimizer.

  Every call builds a new model via `model_fn()` and a new optimizer via
  `server_optimizer_fn()`. It passes both to `_initialize_optimizer_vars` and
  then forwards them, together with the caller's arguments, to
  `server_update`.

  Args:
    server_state: Current server state; passed through unchanged to
      `server_update`.
    model_delta: Model update to apply; passed through unchanged to
      `server_update`.

  Returns:
    The value returned by `server_update` (presumably the updated server
    state; confirm against that function).
  """
  server_model = model_fn()
  optimizer = server_optimizer_fn()
  # NOTE(review): presumably creates the optimizer's internal variables
  # before `server_update` runs. Confirm against `_initialize_optimizer_vars`.
  _initialize_optimizer_vars(server_model, optimizer)
  return server_update(server_model, optimizer, server_state, model_delta)