def test_self_contained_example(self):
    """Two consecutive client updates should each see 2 examples and lower the loss.

    Runs ``simple_fedavg.client_update`` twice against the same model object,
    records the reported loss of each round, checks that every round reports
    ``num_examples == 2`` and that the second loss is strictly below the first.
    """
    make_batches = create_client_data()
    model = MnistTrainableModel()
    loss_history = []
    for _round in range(2):
        # Build the dataset before reading the weights, matching the argument
        # evaluation order of a direct call.
        dataset = make_batches()
        initial_weights = simple_fedavg._get_weights(model)
        result = simple_fedavg.client_update(model, dataset, initial_weights)
        loss_history.append(result.model_output['loss'].numpy())
        self.assertAllEqual(result.optimizer_output['num_examples'].numpy(), 2)
    # NOTE(review): presumably client_update trains `model` in place, so the
    # second round starts from the weights produced by the first — confirm.
    first_loss, second_loss = loss_history
    self.assertLess(second_loss, first_loss)
def server_init(model, optimizer):
    """Returns the initial server state together with the optimizer variables.

    Args:
      model: A `tff.learning.Model`.
      optimizer: An optimizer instance, e.g. `tf.keras.optimizers.SGD`.

    Returns:
      A 2-tuple `(server_state, optimizer_vars)`:
        * `server_state` is a `simple_fedavg.ServerState` whose `model` field
          holds `simple_fedavg._get_weights(model)` and whose
          `optimizer_state` field holds `optimizer_vars`.
        * `optimizer_vars` is the value returned by
          `simple_fedavg._create_optimizer_vars(model, optimizer)`. It is also
          returned on its own so callers can pass it to
          `simple_fedavg.server_update`.
    """
    optimizer_vars = simple_fedavg._create_optimizer_vars(model, optimizer)
    server_state = simple_fedavg.ServerState(
        model=simple_fedavg._get_weights(model),
        optimizer_state=optimizer_vars)
    return (server_state, optimizer_vars)
def _assert_server_update_with_all_ones(self, model_fn, *, learning_rate=0.1,
                                        num_steps=2):
    """Asserts that repeated server updates with an all-ones delta give the expected weights.

    Builds a fresh model and an SGD optimizer, then applies
    `simple_fedavg.server_update` `num_steps` times with a weights delta of
    all ones. It asserts that every trainable variable ends up equal to
    `learning_rate * num_steps`.

    Args:
      model_fn: No-arg callable returning the model under test. The expected
        values below assume the model's trainable weights start at all zeros
        and that there are exactly two trainable variables (asserted below).
      learning_rate: SGD learning rate for the server optimizer. Defaults to
        0.1, matching the original hard-coded value.
      num_steps: Number of server update rounds to apply. Defaults to 2,
        matching the original hard-coded value.
    """
    model = model_fn()
    optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
    state, optimizer_vars = server_init(model, optimizer)
    weights_delta = tf.nest.map_structure(
        tf.ones_like, simple_fedavg._get_weights(model).trainable)
    for _ in range(num_steps):
        state = simple_fedavg.server_update(model, optimizer, optimizer_vars,
                                            state, weights_delta)
    model_vars = self.evaluate(state.model)
    train_vars = model_vars.trainable
    self.assertLen(train_vars, 2)
    # Weights start at zero (assumption on model_fn) and the delta is all
    # ones. Each SGD step therefore shifts every weight by `learning_rate`,
    # giving `learning_rate * num_steps` after all steps. Deriving the value
    # keeps it in sync with both settings instead of hard-coding 0.2.
    expected_value = learning_rate * num_steps
    self.assertAllClose(
        train_vars,
        {k: np.ones_like(v) * expected_value for k, v in train_vars.items()})