Ejemplo n.º 1
0
  def testComputeLRMaxRatio(self, dtype):
    """Checks that `compute_lr` returns `base_lr` for a tiny gradient.

    The weights are all ones and the gradient is 1e-4 everywhere, so the
    weight norm is much larger than the gradient norm. The test expects the
    scaled learning rate to come back equal to `base_lr` (hence "MaxRatio").

    Args:
      dtype: TensorFlow dtype used for the weight and gradient variables.
    """
    shape = [3, 3]
    var_np = np.ones(shape)
    grad_np = np.ones(shape) * 0.0001

    var = tf.Variable(var_np, dtype=dtype, name='a')
    grad = tf.Variable(grad_np, dtype=dtype)

    base_lr = 1.0
    opt = flars_optimizer.FLARSOptimizer(base_lr)
    scaled_lr = opt.compute_lr(base_lr, grad, var, tf.norm(grad))

    # Initialize the variables and materialize the result before comparing:
    # asserting directly on a Tensor only works in eager mode, and in graph
    # mode the variables read by `compute_lr` must be initialized first.
    self.evaluate(tf.compat.v1.global_variables_initializer())
    self.assertAlmostEqual(base_lr, self.evaluate(scaled_lr))
Ejemplo n.º 2
0
def build_federated_averaging_process(
        model_fn,
        client_optimizer_fn,
        server_optimizer_fn=lambda: flars_optimizer.FLARSOptimizer(
            learning_rate=1.0)):
    """Assembles the TFF iterative process for FLARS federated averaging.

    The server-init, server-update, client-update and one-round computations
    are produced by the module's `build_*` helpers and combined into a single
    `tff.templates.IterativeProcess`.

    Args:
      model_fn: A no-arg function that returns a `tff.learning.Model`.
      client_optimizer_fn: A no-arg function that returns a
        `tf.keras.optimizers.Optimizer` for the local client training.
      server_optimizer_fn: A no-arg function that returns the optimizer used
        to apply updates on the server. Defaults to a
        `flars_optimizer.FLARSOptimizer` with `learning_rate=1.0`.

    Returns:
      A `tff.templates.IterativeProcess` whose `initialize_fn` evaluates the
      server initialization at `tff.SERVER` and whose `next_fn` is the
      computation built by `build_run_one_round_fn`.
    """
    # Build a throwaway model inside a fresh graph. Here it is read for
    # metadata (trainable-variable dtypes, input spec) and it is also handed
    # to build_run_one_round_fn.
    with tf.Graph().as_default():
        metadata_model = model_fn()
    grads_norm_dtypes = tuple(
        variable.dtype
        for variable in tf.nest.flatten(metadata_model.trainable_variables))

    init_computation = build_server_init_fn(model_fn, server_optimizer_fn)
    state_type = init_computation.type_signature.result
    model_weights_type = state_type.model

    server_update = build_server_update_fn(
        model_fn, server_optimizer_fn, state_type, model_weights_type,
        grads_norm_dtypes)

    dataset_type = tff.SequenceType(metadata_model.input_spec)
    client_update = build_client_update_fn(
        model_fn, client_optimizer_fn, dataset_type, model_weights_type)

    next_round = build_run_one_round_fn(
        server_update,
        client_update,
        metadata_model,
        tff.type_at_server(state_type),
        tff.type_at_clients(dataset_type))

    return tff.templates.IterativeProcess(
        initialize_fn=tff.federated_computation(
            lambda: tff.federated_eval(init_computation, tff.SERVER)),
        next_fn=next_round)
Ejemplo n.º 3
0
  def testFLARSGradientOneStep(self, dtype, momentum):
    """Compares one FLARS step against a NumPy re-derivation of the update.

    Args:
      dtype: TensorFlow dtype used for the weight and gradient variables.
      momentum: Momentum coefficient passed to the optimizer.
    """
    shape = [3, 3]
    learning_rate = 0.1
    epsilon = 1e-5
    eeta = 0.1
    initial_weights = np.ones(shape)
    gradient = np.ones(shape)
    initial_velocity = np.zeros(shape)

    var = tf.Variable(initial_weights, dtype=dtype, name='a')
    grad = tf.Variable(gradient, dtype=dtype)

    opt = flars_optimizer.FLARSOptimizer(
        learning_rate=learning_rate,
        momentum=momentum,
        eeta=eeta,
        epsilon=epsilon)

    # The optimizer is told the gradient norm up front, keyed by variable.
    grad_norm = np.linalg.norm(gradient.flatten(), ord=2)
    opt.update_grads_norm([var], [grad_norm])

    self.evaluate(tf.compat.v1.global_variables_initializer())
    self.assertAllClose(initial_weights, self.evaluate(var))

    opt.apply_gradients([(grad, var)])
    actual_weights = self.evaluate(var)

    # Reference step: the base learning rate is scaled by the trust ratio
    # eeta * ||w|| / (||g|| + epsilon), then a momentum update is applied
    # starting from zero velocity.
    weight_norm = np.linalg.norm(initial_weights.flatten(), ord=2)
    trust_ratio = eeta * weight_norm / (grad_norm + epsilon)
    local_lr = learning_rate * trust_ratio
    expected_velocity = momentum * initial_velocity - local_lr * gradient
    expected_weights = initial_weights + expected_velocity

    self.assertAllClose(expected_weights, actual_weights)
    if momentum != 0:
      actual_velocity = self.evaluate(opt.get_slot(var, 'momentum'))
      self.assertAllClose(expected_velocity, actual_velocity)
    def test_nan_examples_ignored(self):
        """Checks that non-finite weight deltas leave the server model as-is.

        Every trainable delta is set to +inf and `server_update` is applied
        twice; the model weights afterwards must equal those before.
        """
        server_optimizer = flars_optimizer.FLARSOptimizer(1.0)
        model = _keras_model_fn()
        state, optimizer_vars = server_init(model, server_optimizer)

        # NOTE(review): one norm per trainable variable; assumes the model
        # from _keras_model_fn has exactly two trainable variables — confirm.
        grad_norm = [1.0, 1.0]
        weights_delta = tf.nest.map_structure(
            lambda t: tf.ones_like(t) * float('inf'),
            flars_fedavg.tff.learning.framework.ModelWeights.from_model(
                model).trainable)

        # Snapshot the values, not the containers: if `state.model` holds
        # references to the model's tf.Variables, comparing it after the
        # in-place server updates would compare the variables with themselves
        # and the test could never fail.
        old_model_vars = self.evaluate(state.model)
        for _ in range(2):
            state = flars_fedavg.server_update(model, server_optimizer,
                                               optimizer_vars, state,
                                               weights_delta, grad_norm)
        model_vars = self.evaluate(state.model)
        # Assert the model hasn't changed.
        self.assertAllClose(old_model_vars.trainable, model_vars.trainable)
        self.assertAllClose(old_model_vars.non_trainable,
                            model_vars.non_trainable)