Esempio n. 1
0
  def test_self_contained_example(self):
    """Two successive `client_update` rounds should reduce the training loss.

    The same model instance is passed to both rounds, and its current weights
    (read via `simple_fedavg._get_weights`) are supplied as the starting
    weights each time.
    """
    client_data = create_client_data()
    model = MnistTrainableModel()

    round_losses = []
    round_outputs = None
    for _ in range(2):
      # Build this round's data first, then read the model's current weights.
      dataset = client_data()
      start_weights = simple_fedavg._get_weights(model)
      round_outputs = simple_fedavg.client_update(
          model, dataset, start_weights)
      round_losses.append(round_outputs.model_output['loss'].numpy())

    first_loss, second_loss = round_losses
    # Only the most recent round's outputs are inspected for the example count.
    examples_seen = round_outputs.optimizer_output['num_examples'].numpy()
    self.assertAllEqual(examples_seen, 2)
    self.assertLess(second_loss, first_loss)
Esempio n. 2
0
def server_init(model, optimizer):
  """Builds the initial server state together with its optimizer variables.

  Args:
    model: A `tff.learning.Model` whose current weights (read via
      `simple_fedavg._get_weights`) seed the server state.
    optimizer: The server optimizer instance, e.g. a
      `tf.keras.optimizers.SGD`, handed to
      `simple_fedavg._create_optimizer_vars`.

  Returns:
    A 2-tuple `(state, optimizer_vars)`:
      state: A `simple_fedavg.ServerState` with `model` set to the model's
        weights and `optimizer_state` set to `optimizer_vars`.
      optimizer_vars: The variables returned by
        `simple_fedavg._create_optimizer_vars`. They are also returned on
        their own so callers can pass them directly to
        `simple_fedavg.server_update`.
  """
  # Create the optimizer variables before snapshotting the weights, matching
  # the order in which the server state is assembled.
  optimizer_vars = simple_fedavg._create_optimizer_vars(model, optimizer)
  state = simple_fedavg.ServerState(
      model=simple_fedavg._get_weights(model),
      optimizer_state=optimizer_vars)
  return state, optimizer_vars
Esempio n. 3
0
  def _assert_server_update_with_all_ones(self, model_fn):
    """Checks that two SGD server updates with an all-ones delta reach 0.2.

    Args:
      model_fn: A no-argument callable returning a fresh model. Its trainable
        weights are expected to start at all zeros, and it is expected to
        have exactly two trainable variables.
    """
    learning_rate = 0.1
    num_rounds = 2

    model = model_fn()
    optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
    state, optimizer_vars = server_init(model, optimizer)

    trainable_weights = simple_fedavg._get_weights(model).trainable
    weights_delta = tf.nest.map_structure(tf.ones_like, trainable_weights)

    for _ in range(num_rounds):
      state = simple_fedavg.server_update(
          model, optimizer, optimizer_vars, state, weights_delta)

    train_vars = self.evaluate(state.model).trainable
    self.assertLen(train_vars, 2)

    # Expected outcome: starting from all-zero weights, an all-ones delta with
    # SGD learning rate 0.1 applied for 2 rounds gives 0.2 in every entry.
    expected_value = learning_rate * num_rounds
    actual = list(train_vars.values())
    expected = [np.ones_like(v) * expected_value for v in actual]
    self.assertAllClose(actual, expected)