def _assert_server_update_with_all_ones(self, model_fn):
    """Checks that two all-ones server updates leave every weight at 0.2.

    A small Keras classifier with a zero-initialized dense kernel is built
    inline, server state is created with `server_init`, and
    `attacked_fedavg.server_update` is applied twice with a weights delta
    made of ones. The resulting trainable weights are compared against
    tensors filled with 0.2.

    Args:
        model_fn: Currently unused by this helper; the model under test is
            constructed inline below.
            NOTE(review): confirm whether `model_fn()` was meant to supply
            the model.
    """
    keras = tf.keras
    model = keras.models.Sequential([
        keras.layers.Input(shape=(784, )),
        keras.layers.Flatten(),
        keras.layers.Dense(units=10, kernel_initializer='zeros'),
        keras.layers.Softmax(),
    ])
    optimizer = keras.optimizers.SGD(learning_rate=0.1)
    state, optimizer_vars = server_init(model, optimizer)

    # A delta of ones with the same structure as the trainable weights.
    trainable_weights = attacked_fedavg._get_weights(model).trainable
    ones_delta = tf.nest.map_structure(tf.ones_like, trainable_weights)

    num_rounds = 2
    for _ in range(num_rounds):
        state = attacked_fedavg.server_update(
            model, optimizer, optimizer_vars, state, ones_delta, ())

    trained = self.evaluate(state.model).trainable

    def _expected_like(tensor):
        # Weights start at zero, the delta is all ones and the SGD learning
        # rate is 0.1, so two server rounds are expected to yield 0.2.
        return tf.fill(tf.shape(tensor), 0.2)

    self.assertAllClose(trained,
                        tf.nest.map_structure(_expected_like, trained))
def test_self_contained_example(self):
    """Runs two explicit-boosting client updates and checks the loss drops.

    The same model and optimizer are reused for both rounds. After the loop,
    the last round's reported `num_examples` must equal 2 and the second
    round's loss must be strictly lower than the first.
    """
    make_dataset = create_client_data()
    model = MnistModel()
    optimizer = tf.keras.optimizers.SGD(0.1)
    # A boost factor of 1.0 is passed to the client update under test.
    run_client_update = attacked_fedavg.ClientExplicitBoosting(
        boost_factor=1.0)

    round_losses = []
    for _ in range(2):
        initial_weights = attacked_fedavg._get_weights(model)
        outputs = run_client_update(model, optimizer, make_dataset(),
                                    make_dataset(), tf.constant(False),
                                    initial_weights)
        loss_value = outputs.model_output['loss'].numpy()
        round_losses.append(loss_value)

    num_examples = outputs.optimizer_output['num_examples'].numpy()
    self.assertAllEqual(num_examples, 2)

    first_loss, second_loss = round_losses
    self.assertLess(second_loss, first_loss)
def server_init(model, optimizer, delta_aggregate_state=()):
    """Builds the initial server state together with the optimizer variables.

    Args:
        model: The model whose weights seed the server state; it is passed to
            `attacked_fedavg._get_weights` and
            `attacked_fedavg._create_optimizer_vars`.
        optimizer: The optimizer for which variables are created (callers in
            this file pass a `tf.keras.optimizers.SGD`).
        delta_aggregate_state: Value stored unchanged in the returned state's
            `delta_aggregate_state` field. Defaults to an empty tuple.

    Returns:
        A 2-tuple `(server_state, optimizer_vars)`: an
        `attacked_fedavg.ServerState` holding the model weights, the optimizer
        variables and `delta_aggregate_state`, followed by the same optimizer
        variables object.
    """
    optimizer_vars = attacked_fedavg._create_optimizer_vars(model, optimizer)
    initial_weights = attacked_fedavg._get_weights(model)
    initial_state = attacked_fedavg.ServerState(
        model=initial_weights,
        optimizer_state=optimizer_vars,
        delta_aggregate_state=delta_aggregate_state)
    return initial_state, optimizer_vars