def _assert_server_update_with_all_ones(self, model_fn):
  """Checks that two server updates with an all-ones delta yield weights of 0.2.

  Builds a model with `model_fn`, initializes server state for it with an
  SGD(learning_rate=0.1) optimizer, and applies `attacked_fedavg.server_update`
  twice with a weights delta of all ones. The helper then asserts that the
  model has exactly two trainable variables and that both equal 0.2
  element-wise.

  Args:
    model_fn: No-argument callable returning the model under test. The
      previous version ignored this argument and built its own model. The
      expected value of 0.2 assumes the returned model's trainable weights
      start at zero and that there are exactly two of them (e.g. a kernel and
      a bias); confirm this for each caller.
  """
  # Use the caller-supplied model; previously this argument was ignored.
  model = model_fn()
  optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
  state, optimizer_vars = server_init(model, optimizer)

  # A delta with the same structure as the trainable weights, filled with ones.
  weights_delta = tf.nest.map_structure(
      tf.ones_like, attacked_fedavg._get_weights(model).trainable)

  for _ in range(2):
    state = attacked_fedavg.server_update(model, optimizer, optimizer_vars,
                                          state, weights_delta, ())

  train_vars = self.evaluate(state.model).trainable
  self.assertLen(train_vars, 2)

  # Weights start at zero and the delta is all ones. With an SGD learning
  # rate of 0.1, two server steps are expected to give 2 * 0.1 * 1 = 0.2.
  values = list(train_vars.values())
  self.assertAllClose(values, [np.ones_like(v) * 0.2 for v in values])
def _assert_server_update_with_all_ones(self, model_fn):
  """Applies two all-ones server updates and verifies the resulting weights.

  The model returned by `model_fn` is paired with an SGD optimizer
  (learning rate 0.1) via `server_init`. `attacked_fedavg.server_update` is
  then called twice with a delta of ones shaped like the trainable weights.
  Afterwards the model must have exactly two trainable variables, each
  approximately equal to 0.2 everywhere.

  Args:
    model_fn: No-argument callable that builds the model under test. The 0.2
      expectation presumes zero-initialized trainable weights; verify this
      against the callers.
  """
  num_rounds = 2
  keras_model = model_fn()
  server_optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
  server_state, server_optimizer_vars = server_init(keras_model,
                                                    server_optimizer)

  trainable_weights = attacked_fedavg._get_weights(keras_model).trainable
  all_ones_delta = tf.nest.map_structure(tf.ones_like, trainable_weights)

  for _ in range(num_rounds):
    server_state = attacked_fedavg.server_update(
        keras_model, server_optimizer, server_optimizer_vars, server_state,
        all_ones_delta, ())

  final_trainable = self.evaluate(server_state.model).trainable
  self.assertLen(final_trainable, 2)

  # Zero-initialized weights plus two SGD steps (lr 0.1) of an all-ones
  # delta are expected to land at 0.2.
  observed = list(final_trainable.values())
  expected = [np.ones_like(tensor) * 0.2 for tensor in observed]
  self.assertAllClose(observed, expected)