def testComputeLRMaxRatio(self, dtype):
  shape = [3, 3]
  var_np = np.ones(shape)
  grad_np = np.ones(shape) * 0.0001
  var = tf.Variable(var_np, dtype=dtype, name='a')
  grad = tf.Variable(grad_np, dtype=dtype)

  base_lr = 1.0
  opt = flars_optimizer.FLARSOptimizer(base_lr)
  scaled_lr = opt.compute_lr(base_lr, grad, var, tf.norm(grad))
  self.assertAlmostEqual(base_lr, scaled_lr)
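# A minimal NumPy sketch of the scaling `compute_lr` appears to verify,
# assuming a LARS-style trust ratio eeta * ||w|| / (||g|| + epsilon) that is
# clipped to a maximum value. The defaults below (eeta, epsilon, max_ratio)
# are illustrative assumptions, not necessarily FLARSOptimizer's own; with
# the test's near-zero gradient the raw ratio exceeds the cap, so the scaled
# learning rate falls back to base_lr.
import numpy as np

def reference_scaled_lr(base_lr, var_np, grad_np,
                        eeta=0.001, epsilon=1e-5, max_ratio=1.0):
  w_norm = np.linalg.norm(var_np.flatten(), ord=2)
  g_norm = np.linalg.norm(grad_np.flatten(), ord=2)
  trust_ratio = min(eeta * w_norm / (g_norm + epsilon), max_ratio)
  return base_lr * trust_ratio

print(reference_scaled_lr(1.0, np.ones([3, 3]), np.ones([3, 3]) * 0.0001))
# -> 1.0, matching assertAlmostEqual(base_lr, scaled_lr) above.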
def build_federated_averaging_process(
    model_fn,
    client_optimizer_fn,
    server_optimizer_fn=lambda: flars_optimizer.FLARSOptimizer(
        learning_rate=1.0)):
  """Builds the TFF computations for optimization using federated averaging.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    client_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer` for the local client training.
    server_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer` for applying updates on the server.

  Returns:
    A `tff.templates.IterativeProcess`.
  """
  with tf.Graph().as_default():
    dummy_model_for_metadata = model_fn()
    type_signature_grads_norm = tuple(
        weight.dtype for weight in tf.nest.flatten(
            dummy_model_for_metadata.trainable_variables))

  server_init_tf = build_server_init_fn(model_fn, server_optimizer_fn)
  server_state_type = server_init_tf.type_signature.result
  server_update_fn = build_server_update_fn(model_fn, server_optimizer_fn,
                                            server_state_type,
                                            server_state_type.model,
                                            type_signature_grads_norm)

  tf_dataset_type = tff.SequenceType(dummy_model_for_metadata.input_spec)
  client_update_fn = build_client_update_fn(model_fn, client_optimizer_fn,
                                            tf_dataset_type,
                                            server_state_type.model)

  federated_server_state_type = tff.type_at_server(server_state_type)
  federated_dataset_type = tff.type_at_clients(tf_dataset_type)
  run_one_round_tff = build_run_one_round_fn(server_update_fn,
                                             client_update_fn,
                                             dummy_model_for_metadata,
                                             federated_server_state_type,
                                             federated_dataset_type)

  return tff.templates.IterativeProcess(
      initialize_fn=tff.federated_computation(
          lambda: tff.federated_eval(server_init_tf, tff.SERVER)),
      next_fn=run_one_round_tff)
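# A hedged usage sketch for the process built above. `model_fn` and
# `federated_train_data` are hypothetical stand-ins a caller must supply (a
# no-arg `tff.learning.Model` constructor and a list of per-client
# `tf.data.Dataset`s); only initialize()/next() come from the
# `tff.templates.IterativeProcess` contract.
iterative_process = build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.1))
state = iterative_process.initialize()
for round_num in range(5):
  state, metrics = iterative_process.next(state, federated_train_data)
  print('round {:2d}, metrics={}'.format(round_num, metrics))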
def testFLARSGradientOneStep(self, dtype, momentum):
  shape = [3, 3]
  var_np = np.ones(shape)
  grad_np = np.ones(shape)
  lr_np = 0.1
  m_np = momentum
  ep_np = 1e-5
  eeta = 0.1
  vel_np = np.zeros(shape)

  var = tf.Variable(var_np, dtype=dtype, name='a')
  grad = tf.Variable(grad_np, dtype=dtype)
  opt = flars_optimizer.FLARSOptimizer(
      learning_rate=lr_np, momentum=m_np, eeta=eeta, epsilon=ep_np)

  g_norm = np.linalg.norm(grad_np.flatten(), ord=2)
  opt.update_grads_norm([var], [g_norm])

  self.evaluate(tf.compat.v1.global_variables_initializer())
  pre_var = self.evaluate(var)
  self.assertAllClose(var_np, pre_var)
  opt.apply_gradients([(grad, var)])
  post_var = self.evaluate(var)

  # Expected one-step update: scale the base learning rate by the trust
  # ratio eeta * ||w|| / (||g|| + epsilon), then apply the momentum rule.
  w_norm = np.linalg.norm(var_np.flatten(), ord=2)
  trust_ratio = eeta * w_norm / (g_norm + ep_np)
  scaled_lr = lr_np * trust_ratio
  vel_np = m_np * vel_np - scaled_lr * grad_np
  var_np += vel_np

  self.assertAllClose(var_np, post_var)
  if m_np != 0:
    post_vel = self.evaluate(opt.get_slot(var, 'momentum'))
    self.assertAllClose(vel_np, post_vel)
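# Working the numbers above for momentum 0: a 3x3 all-ones tensor has L2
# norm 3, so trust_ratio = 0.1 * 3 / (3 + 1e-5) ≈ 0.1, scaled_lr ≈ 0.01,
# and every entry of `var` should move from 1.0 to roughly 0.99. A quick
# standalone check of that arithmetic:
import numpy as np

w_norm = g_norm = np.linalg.norm(np.ones([3, 3]).flatten(), ord=2)  # 3.0
trust_ratio = 0.1 * w_norm / (g_norm + 1e-5)
scaled_lr = 0.1 * trust_ratio
print(scaled_lr)              # ~0.0099999667
print(1.0 - scaled_lr * 1.0)  # expected post-step entry, ~0.99000003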
def test_nan_examples_ignored(self):
  server_optimizer_fn = lambda: flars_optimizer.FLARSOptimizer(1.0)
  model = _keras_model_fn()
  server_optimizer = server_optimizer_fn()
  state, optimizer_vars = server_init(model, server_optimizer)

  grad_norm = [1.0, 1.0]
  # Build an all-inf delta; the server update should ignore non-finite
  # updates and leave the model untouched.
  weights_delta = tf.nest.map_structure(
      lambda t: tf.ones_like(t) * float('inf'),
      flars_fedavg.tff.learning.framework.ModelWeights.from_model(
          model).trainable)

  old_model_vars = state.model
  for _ in range(2):
    state = flars_fedavg.server_update(model, server_optimizer,
                                       optimizer_vars, state, weights_delta,
                                       grad_norm)
  model_vars = state.model

  # Assert the model hasn't changed.
  self.assertAllClose(old_model_vars.trainable, model_vars.trainable)
  self.assertAllClose(old_model_vars.non_trainable, model_vars.non_trainable)
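# One plausible guard behind the behavior this test checks, sketched in
# plain TensorFlow: zero out the round's delta whenever any entry is
# non-finite, so an inf/NaN update leaves the model untouched. This is an
# illustrative mechanism, not necessarily the exact logic inside
# flars_fedavg.server_update.
import tensorflow as tf

def zero_delta_if_any_non_finite(weights_delta):
  """Returns `weights_delta`, or an all-zeros copy if any value is non-finite."""
  flat = tf.nest.flatten(weights_delta)
  all_finite = tf.reduce_all(
      [tf.reduce_all(tf.math.is_finite(t)) for t in flat])
  return tf.nest.map_structure(
      lambda t: tf.where(all_finite, t, tf.zeros_like(t)), weights_delta)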