Example #1
0
def _server_init(model, optimizer):
    """Build the `ServerState` that a run starts from.

    The optimizer's variables are initialized first, via
    `simple_fedavg_tff._initialize_optimizer_vars`. This is presumably done so
    that `optimizer.variables()` is populated when the state is captured;
    confirm against that helper.

    Args:
        model: A `tff.learning.Model`. Its `weights` seed the server state.
        optimizer: The optimizer whose `variables()` are stored as the
            server's optimizer state.

    Returns:
        A `simple_fedavg_tf.ServerState` holding the model weights, the
        optimizer variables, and `round_num` set to 0.
    """
    simple_fedavg_tff._initialize_optimizer_vars(model, optimizer)
    # Read the weights before the optimizer variables, matching the
    # keyword-argument evaluation order of the original construction.
    initial_weights = model.weights
    initial_optimizer_state = optimizer.variables()
    return simple_fedavg_tf.ServerState(
        model_weights=initial_weights,
        optimizer_state=initial_optimizer_state,
        round_num=0,
    )
Example #2
0
    def test_self_contained_example(self):
        """Two rounds of client updates should lower the reported loss."""
        client_data = _create_client_data()
        model = MnistModel()

        def make_optimizer():
            # A fresh SGD optimizer is built for every round.
            return tf.keras.optimizers.SGD(learning_rate=0.1)

        round_losses = []
        for round_num in range(2):
            client_optimizer = make_optimizer()
            simple_fedavg_tff._initialize_optimizer_vars(model, client_optimizer)
            broadcast = simple_fedavg_tf.BroadcastMessage(
                model_weights=model.weights, round_num=round_num)
            outputs = simple_fedavg_tf.client_update(
                model, client_data(), broadcast, client_optimizer)
            round_losses.append(outputs.model_output.numpy())

        # `outputs` refers to the final round here. The expected weight of 2 is
        # presumably tied to the size of `_create_client_data()`; confirm there.
        self.assertAllEqual(int(outputs.client_weight.numpy()), 2)
        first_loss, second_loss = round_losses
        self.assertLess(second_loss, first_loss)