# Example no. 1
# 0
    def test_nan_examples_ignored(self):
        """Server updates built from non-finite deltas must leave the model as-is."""
        clients = create_client_data()
        sample_batch = self.evaluate(next(iter(clients())))

        def build_model():
            return tff.learning.from_compiled_keras_model(
                tff.simulation.models.mnist.create_simple_keras_model(),
                sample_batch)

        model = build_model()
        optimizer = flars_optimizer.FLARSOptimizer(learning_rate=1.0)
        state, optimizer_vars = server_init(model, optimizer)

        grad_norms = [1.0, 1.0]
        # Every entry of the delta is +inf, so the whole update should be
        # rejected rather than applied.
        inf_delta = tf.nest.map_structure(
            lambda t: tf.ones_like(t) * np.inf,
            flars_fedavg._get_weights(model).trainable)

        before = self.evaluate(state.model)
        for _ in range(2):
            state = flars_fedavg.server_update(model, optimizer,
                                               optimizer_vars, state,
                                               inf_delta, grad_norms)
        after = self.evaluate(state.model)

        self.assertAllClose(before._asdict(), after._asdict())
    def testComputeLRMaxRatio(self, dtype):
        """compute_lr should not scale the base LR when the gradient is tiny."""
        dims = [3, 3]
        weights_np = np.ones(dims)
        # A very small gradient drives the trust ratio toward its cap.
        grads_np = np.ones(dims) * 0.0001

        weights = tf.Variable(weights_np, dtype=dtype, name='a')
        grads = tf.Variable(grads_np, dtype=dtype)

        base_lr = 1.0
        optimizer = flars_optimizer.FLARSOptimizer(base_lr)
        scaled = optimizer.compute_lr(base_lr, grads, weights, tf.norm(grads))
        self.assertAlmostEqual(base_lr, scaled)
# Example no. 3
# 0
def build_federated_averaging_process(
    model_fn,
    client_optimizer_fn,
    server_optimizer_fn=lambda: flars_optimizer.FLARSOptimizer(learning_rate=
                                                               1.0)):
    """Builds the TFF computations for optimization using federated averaging.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    client_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer` for the local client training.
    server_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer` for applying updates on the server.

  Returns:
    A `tff.templates.IterativeProcess`.
  """
    # Build one model inside a scratch graph purely to read off metadata
    # (trainable-variable dtypes and the input spec); it is never trained.
    with tf.Graph().as_default():
        metadata_model = model_fn()
    grads_norm_type = tuple(
        v.dtype for v in tf.nest.flatten(metadata_model.trainable_variables))

    init_tf = build_server_init_fn(model_fn, server_optimizer_fn)
    state_type = init_tf.type_signature.result

    update_server = build_server_update_fn(model_fn, server_optimizer_fn,
                                           state_type, state_type.model,
                                           grads_norm_type)

    dataset_type = tff.SequenceType(metadata_model.input_spec)
    update_client = build_client_update_fn(model_fn, client_optimizer_fn,
                                           dataset_type, state_type.model)

    one_round = build_run_one_round_fn(
        update_server, update_client, metadata_model,
        tff.FederatedType(state_type, tff.SERVER),
        tff.FederatedType(dataset_type, tff.CLIENTS))

    return tff.templates.IterativeProcess(
        initialize_fn=tff.federated_computation(
            lambda: tff.federated_eval(init_tf, tff.SERVER)),
        next_fn=one_round)
# Example no. 4
# 0
def build_federated_averaging_process(
    model_fn,
    server_optimizer_fn=lambda: flars_optimizer.FLARSOptimizer(learning_rate=
                                                               1.0)):
    """Builds the TFF computations for optimization using federated averaging.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.TrainableModel`.
    server_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer`.

  Returns:
    A `tff.utils.IterativeProcess`.
  """
    # Instantiate one model only to inspect its metadata; it is never trained.
    metadata_model = model_fn()

    grads_norm_type = tff.NamedTupleType([
        v.dtype
        for v in tf.nest.flatten(_get_weights(metadata_model).trainable)
    ])

    init_tf = build_server_init_fn(model_fn, server_optimizer_fn)
    state_type = init_tf.type_signature.result
    update_server = build_server_update_fn(model_fn, server_optimizer_fn,
                                           state_type, state_type.model,
                                           grads_norm_type)

    dataset_type = tff.SequenceType(metadata_model.input_spec)
    update_client = build_client_update_fn(model_fn, dataset_type,
                                           state_type.model)

    one_round = build_run_one_round_fn(
        update_server, update_client, metadata_model,
        tff.FederatedType(state_type, tff.SERVER),
        tff.FederatedType(dataset_type, tff.CLIENTS))

    return tff.utils.IterativeProcess(
        initialize_fn=tff.federated_computation(
            lambda: tff.federated_value(init_tf(), tff.SERVER)),
        next_fn=one_round)
# Example no. 5
# 0
  def test_nan_examples_ignored(self):
    """A weights delta made entirely of infs must not change the server model."""
    model = _keras_model_fn()
    server_optimizer = flars_optimizer.FLARSOptimizer(1.0)
    state, optimizer_vars = server_init(model, server_optimizer)

    grad_norms = [1.0, 1.0]
    # Non-finite delta: the server update is expected to reject it outright.
    inf_delta = tf.nest.map_structure(
        lambda t: tf.ones_like(t) * np.inf,
        flars_fedavg._get_weights(model).trainable)

    before = self.evaluate(state.model)
    for _ in range(2):
      state = flars_fedavg.server_update(model, server_optimizer,
                                         optimizer_vars, state, inf_delta,
                                         grad_norms)
    after = self.evaluate(state.model)

    self.assertAllClose(before._asdict(), after._asdict())
    def testFLARSGradientOneStep(self, dtype, momentum):
        """One apply_gradients step must match the hand-computed FLARS update."""
        dims = [3, 3]
        weights_np = np.ones(dims)
        grads_np = np.ones(dims)
        lr = 0.1
        m = momentum
        eps = 1e-5
        eeta = 0.1
        velocity_np = np.zeros(dims)

        weights = tf.Variable(weights_np, dtype=dtype, name='a')
        grads = tf.Variable(grads_np, dtype=dtype)

        opt = flars_optimizer.FLARSOptimizer(
            learning_rate=lr, momentum=m, eeta=eeta, epsilon=eps)

        grads_norm = np.linalg.norm(grads_np.flatten(), ord=2)
        opt.update_grads_norm([weights], [grads_norm])

        self.evaluate(tf.compat.v1.global_variables_initializer())

        # The variable must start untouched.
        self.assertAllClose(weights_np, self.evaluate(weights))

        opt.apply_gradients([(grads, weights)])

        post_var = self.evaluate(weights)

        # Re-derive the expected update by hand:
        #   trust_ratio = eeta * ||w|| / (||g|| + eps)
        #   v <- m * v - lr * trust_ratio * g ;  w <- w + v
        weights_norm = np.linalg.norm(weights_np.flatten(), ord=2)
        scaled_lr = lr * (eeta * weights_norm / (grads_norm + eps))

        velocity_np = m * velocity_np - scaled_lr * grads_np
        weights_np += velocity_np

        self.assertAllClose(weights_np, post_var)
        if m != 0:
            post_vel = self.evaluate(opt.get_slot(weights, 'momentum'))
            self.assertAllClose(velocity_np, post_vel)
# Example no. 7
# 0
  def test_nan_examples_ignored(self):
    """Infinite weight deltas must leave the server model unchanged."""
    model = _keras_model_fn()
    server_optimizer = flars_optimizer.FLARSOptimizer(1.0)
    state, optimizer_vars = server_init(model, server_optimizer)

    grad_norms = [1.0, 1.0]
    # Non-finite delta: the server update is expected to reject it outright.
    inf_delta = tf.nest.map_structure(
        lambda t: tf.ones_like(t) * float('inf'),
        flars_fedavg.tff.learning.framework.ModelWeights.from_model(
            model).trainable)

    before = state.model
    for _ in range(2):
      state = flars_fedavg.server_update(model, server_optimizer,
                                         optimizer_vars, state, inf_delta,
                                         grad_norms)
    after = state.model

    # Assert the model hasn't changed.
    self.assertAllClose(before.trainable, after.trainable)
    self.assertAllClose(before.non_trainable, after.non_trainable)