def test_nan_examples_ignored(self):
    """Non-finite client deltas must leave the server model unchanged.

    Despite the test name, the delta used here is filled with `np.inf`: every
    trainable tensor of `weights_delta` is non-finite, and the test checks that
    `flars_fedavg.server_update` does not move the server model weights, even
    when the same update is applied for several rounds.
    """
    client_data = create_client_data()
    # A concrete input batch is required to wrap the Keras model for TFF.
    sample_batch = self.evaluate(next(iter(client_data())))
    model = tff.learning.from_compiled_keras_model(
        tff.simulation.models.mnist.create_simple_keras_model(), sample_batch)
    optimizer = flars_optimizer.FLARSOptimizer(learning_rate=1.0)
    state, optimizer_vars = server_init(model, optimizer)

    grad_norm = [1.0, 1.0]
    # Same structure as the trainable weights, but every entry is +inf.
    weights_delta = tf.nest.map_structure(
        lambda t: tf.ones_like(t) * np.inf,
        flars_fedavg._get_weights(model).trainable)

    old_model_vars = self.evaluate(state.model)
    for _ in range(2):
        state = flars_fedavg.server_update(model, optimizer, optimizer_vars,
                                           state, weights_delta, grad_norm)
        model_vars = self.evaluate(state.model)
        # Checked after every round: no round may change the model.
        self.assertAllClose(old_model_vars._asdict(), model_vars._asdict())
def test_self_contained_example(self):
    """Runs `flars_fedavg.client_update` once on the test client data.

    Verifies the reported delta weight, the number of per-variable gradient
    norms collected by the optimizer output, and the example count.
    """
    make_dataset = create_client_data()
    keras_model = _keras_model_fn()
    dataset = make_dataset()
    initial_weights = flars_fedavg._get_weights(keras_model)
    result = self.evaluate(
        flars_fedavg.client_update(keras_model, dataset, initial_weights))

    self.assertAllEqual(result.weights_delta_weight, 2)
    optimizer_output = result.optimizer_output
    # One gradient norm per trainable variable: kernel + bias for each of the
    # two Conv and two Dense layers = 8 (pooling layers own no variables).
    self.assertLen(optimizer_output['flat_grads_norm_sum'], 8)
    self.assertEqual(optimizer_output['num_examples'], 2)
def server_init(model, optimizer):
    """Builds the initial server state together with the optimizer variables.

    Args:
      model: A `tff.learning.Model`.
      optimizer: A `tf.keras.optimizers.Optimizer`.

    Returns:
      A 2-tuple `(state, optimizer_vars)` where:
        * `state` is a `flars_fedavg.ServerState` whose `model` field holds the
          model's current weights and whose `optimizer_state` field is
          `optimizer_vars`.
        * `optimizer_vars` is the value returned by
          `flars_fedavg._create_optimizer_vars(model, optimizer)`. It is
          returned separately so callers can pass it to
          `flars_fedavg.server_update` directly.
    """
    optimizer_vars = flars_fedavg._create_optimizer_vars(model, optimizer)
    state = flars_fedavg.ServerState(
        model=flars_fedavg._get_weights(model),
        optimizer_state=optimizer_vars)
    return state, optimizer_vars
def test_nan_examples_ignored(self):
    """Non-finite client deltas must leave the server model unchanged.

    Despite the test name, the delta used here is filled with `np.inf`: every
    trainable tensor of `weights_delta` is non-finite, and the test checks that
    `flars_fedavg.server_update` does not move the server model weights, even
    when the same update is applied for several rounds.
    """
    model = _keras_model_fn()
    server_optimizer = flars_optimizer.FLARSOptimizer(1.0)
    state, optimizer_vars = server_init(model, server_optimizer)

    grad_norm = [1.0, 1.0]
    # Same structure as the trainable weights, but every entry is +inf.
    weights_delta = tf.nest.map_structure(
        lambda t: tf.ones_like(t) * np.inf,
        flars_fedavg._get_weights(model).trainable)

    old_model_vars = self.evaluate(state.model)
    for _ in range(2):
        state = flars_fedavg.server_update(model, server_optimizer,
                                           optimizer_vars, state,
                                           weights_delta, grad_norm)
        model_vars = self.evaluate(state.model)
        # Checked after every round: no round may change the model.
        self.assertAllClose(old_model_vars._asdict(), model_vars._asdict())