def test_nan_examples_ignored(self):
        """Infinite client weight deltas must leave the server model unchanged.

        Every trainable weight delta is set to +inf and `server_update` is
        applied twice; the resulting model weights must match the initial
        ones.
        """
        # A sample batch is only needed so the Keras model can be wrapped as
        # a `tff.learning` model.
        make_dataset = create_client_data()
        example_batch = self.evaluate(next(iter(make_dataset())))

        tff_model = tff.learning.from_compiled_keras_model(
            tff.simulation.models.mnist.create_simple_keras_model(),
            example_batch)
        flars = flars_optimizer.FLARSOptimizer(learning_rate=1.0)
        server_state, opt_vars = server_init(tff_model, flars)

        norms = [1.0, 1.0]
        # Every trainable weight receives an infinite update.
        infinite_delta = tf.nest.map_structure(
            lambda w: tf.ones_like(w) * np.inf,
            flars_fedavg._get_weights(tff_model).trainable)

        initial_weights = self.evaluate(server_state.model)
        for _step in range(2):
            server_state = flars_fedavg.server_update(
                tff_model, flars, opt_vars, server_state, infinite_delta,
                norms)
        final_weights = self.evaluate(server_state.model)

        self.assertAllClose(initial_weights._asdict(),
                            final_weights._asdict())
    def test_self_contained_example(self):
        """Runs `client_update` once and checks the reported outputs."""
        make_dataset = create_client_data()
        keras_model = _keras_model_fn()
        update = flars_fedavg.client_update(
            keras_model, make_dataset(),
            flars_fedavg._get_weights(keras_model))
        outputs = self.evaluate(update)

        # The expected count of 2 presumably reflects the number of examples
        # produced by `create_client_data` (not visible here) - confirm if
        # that fixture changes.
        self.assertAllEqual(outputs.weights_delta_weight, 2)

        # One gradient norm is expected per trainable variable.
        # NOTE(review): the previous tally read
        #   [Conv, Pool, Conv, Pool, Dense + Bias, Dense + Bias] = 8,
        # but pooling layers normally hold no trainable variables; the 8 more
        # likely comes from kernel + bias of two conv and two dense layers.
        # Confirm against `_keras_model_fn`.
        grad_norms = outputs.optimizer_output['flat_grads_norm_sum']
        self.assertLen(grad_norms, 8)
        self.assertEqual(outputs.optimizer_output['num_examples'], 2)
Example #3
0
def server_init(model, optimizer):
  """Builds the initial FLARS server state and its optimizer variables.

  Args:
    model: A `tff.learning.Model`.
    optimizer: A `tf.keras.optimizers.Optimizer`.

  Returns:
    A 2-tuple `(server_state, optimizer_vars)`:
      server_state: a `flars_fedavg.ServerState` whose `model` field holds
        the weights of `model` and whose `optimizer_state` field is
        `optimizer_vars`.
      optimizer_vars: the optimizer variables created by
        `flars_fedavg._create_optimizer_vars`. The same object is stored in
        `server_state.optimizer_state`; it is also returned separately so
        callers can pass it to `flars_fedavg.server_update`.
  """
  optimizer_vars = flars_fedavg._create_optimizer_vars(model, optimizer)
  server_state = flars_fedavg.ServerState(
      model=flars_fedavg._get_weights(model),
      optimizer_state=optimizer_vars)
  return server_state, optimizer_vars
Example #4
0
  def test_nan_examples_ignored(self):
    """Infinite weight deltas must leave the server model unchanged."""
    keras_model = _keras_model_fn()
    flars = flars_optimizer.FLARSOptimizer(1.0)
    server_state, opt_vars = server_init(keras_model, flars)

    norms = [1.0, 1.0]
    # Push +inf into every trainable weight; the test expects these
    # non-finite updates to be skipped by `server_update`.
    inf_delta = tf.nest.map_structure(
        lambda var: tf.ones_like(var) * np.inf,
        flars_fedavg._get_weights(keras_model).trainable)

    before = self.evaluate(server_state.model)
    for _round in range(2):
      server_state = flars_fedavg.server_update(
          keras_model, flars, opt_vars, server_state, inf_delta, norms)
    after = self.evaluate(server_state.model)

    self.assertAllClose(before._asdict(), after._asdict())