def _server_init(model, optimizer):
  """Builds the starting `ServerState` for a federated training run.

  The optimizer's variables are initialized first (via
  `simple_fedavg_tff._initialize_optimizer_vars`). Then the model weights and
  `optimizer.variables()` are read into the returned state.

  Args:
    model: A `tff.learning.Model`.
    optimizer: A `tf.train.Optimizer`.

  Returns:
    A `ServerState` namedtuple whose `model_weights` are `model.weights`, whose
    `optimizer_state` is `optimizer.variables()`, and whose `round_num` is 0.
  """
  simple_fedavg_tff._initialize_optimizer_vars(model, optimizer)
  # Read in the same order as the state fields: weights, then optimizer vars.
  weights = model.weights
  opt_vars = optimizer.variables()
  return simple_fedavg_tf.ServerState(
      model_weights=weights,
      optimizer_state=opt_vars,
      round_num=0)
def test_self_contained_example(self):
  """Runs two rounds of `client_update` and checks the loss goes down.

  Each round also checks that the reported client weight is 2. The same
  `model` instance is reused across both rounds.
  """
  client_data = _create_client_data()
  model = MnistModel()

  def optimizer_fn():
    return tf.keras.optimizers.SGD(learning_rate=0.1)

  losses = []
  for round_num in range(2):
    # A fresh optimizer is built each round and its variables are initialized
    # against the shared model before the client update runs.
    optimizer = optimizer_fn()
    simple_fedavg_tff._initialize_optimizer_vars(model, optimizer)
    message = simple_fedavg_tf.BroadcastMessage(
        model_weights=model.weights, round_num=round_num)
    outputs = simple_fedavg_tf.client_update(
        model, client_data(), message, optimizer)
    losses.append(outputs.model_output.numpy())
    self.assertAllEqual(int(outputs.client_weight.numpy()), 2)

  self.assertLess(losses[1], losses[0])