def _assert_server_update_with_all_ones(self, model_fn):
    """Asserts two server SGD steps on all-ones deltas yield 0.2 weights."""
    model = model_fn()
    state = _server_init(model, optimizer_utils.SGDServerOptimizer(0.1))
    all_ones_delta = tf.nest.map_structure(tf.ones_like,
                                           model.trainable_variables)

    update_optimizer = optimizer_utils.SGDServerOptimizer(0.1)
    for _ in range(2):
        state = dp_fedavg.server_update(model, update_optimizer, state,
                                        all_ones_delta)

    trained = self.evaluate(state.model).trainable
    self.assertLen(trained, 2)
    self.assertEqual(state.round_num, 2)
    # Weights start at all zeros; each step adds lr (0.1) * delta (all ones),
    # so after two rounds every weight equals 0.2.
    self.assertAllClose(trained, [np.ones_like(v) * 0.2 for v in trained])
  def test_deterministic_sgd(self):
    """Two SGD steps (lr=0.1) on all-ones grads drive weights to -0.2."""
    variables = _create_model_variables()
    ones_grad = tf.nest.map_structure(tf.ones_like, variables)
    sgd = optimizer_utils.SGDServerOptimizer(learning_rate=0.1)

    opt_state = sgd.init_state()
    for step in range(2):
      opt_state = sgd.model_update(opt_state, variables, ones_grad, step)

    self.assertLen(variables, 2)
    # Variables start at all zeros; each update subtracts 0.1 * 1 from every
    # entry, so after two deterministic steps each entry is -0.2.
    flat = tf.nest.flatten(variables)
    self.assertAllClose(flat, [-0.2 * np.ones_like(v) for v in flat])
def _server_optimizer_fn(model_weights, name, learning_rate, noise_std):
    """Returns the server optimizer selected by `name`.

    Args:
      model_weights: Structure of model weight variables; only their shapes
        and dtypes are used to build `tf.TensorSpec`s for the DP optimizers.
      name: One of 'sgd', 'sgdm', 'dpftrl', 'dpsgd', 'dpsgdm', 'dpftrlm'.
      learning_rate: Server learning rate.
      noise_std: Standard deviation of DP noise (ignored by 'sgd'; forced to
        0 for 'sgdm').

    Raises:
      ValueError: If `name` is not a recognized optimizer name.
    """
    model_weight_specs = tf.nest.map_structure(
        lambda v: tf.TensorSpec(v.shape, v.dtype), model_weights)

    def _dpsgdm(momentum, std):
        # DP-SGD with momentum; `std == 0` disables the added noise.
        return optimizer_utils.DPSGDMServerOptimizer(
            learning_rate,
            momentum=momentum,
            noise_std=std,
            model_weight_specs=model_weight_specs)

    def _dpftrlm(momentum):
        # DP-FTRL with momentum.
        return optimizer_utils.DPFTRLMServerOptimizer(
            learning_rate,
            momentum=momentum,
            noise_std=noise_std,
            model_weight_specs=model_weight_specs)

    if name == 'sgd':
        return optimizer_utils.SGDServerOptimizer(learning_rate)
    if name == 'sgdm':
        # Momentum-only variant: uses the DP optimizer with zero noise.
        return _dpsgdm(FLAGS.server_momentum, 0)
    if name == 'dpftrl':
        return _dpftrlm(0)
    if name == 'dpsgd':
        return _dpsgdm(0, noise_std)
    if name == 'dpsgdm':
        return _dpsgdm(FLAGS.server_momentum, noise_std)
    if name == 'dpftrlm':
        return _dpftrlm(FLAGS.server_momentum)
    raise ValueError('Unknown server optimizer name {}'.format(name))
# Beispiel #4
def build_federated_averaging_process(
    model_fn,
    dp_clip_norm=1.0,
    server_optimizer_fn=lambda w: optimizer_utils.SGDServerOptimizer(  # pylint: disable=g-long-lambda
        learning_rate=1.0),
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.1)):
    """Builds the TFF computations for optimization using federated averaging.

    Args:
      model_fn: A no-arg function that returns a
        `dp_fedavg_tf.KerasModelWrapper`.
      dp_clip_norm: Clip norm stored in the server state and broadcast to
        clients; if < 0, no clipping is applied.
      server_optimizer_fn: A one-arg function mapping the model's trainable
        weights to a server optimizer exposing `init_state()` (e.g. from
        `optimizer_utils`).
      client_optimizer_fn: A no-arg function that returns a
        `tf.keras.optimizers.Optimizer` for client update.

    Returns:
      A `tff.templates.IterativeProcess`.
    """

    # Instantiated once only to read `input_spec` for the dataset type below.
    example_model = model_fn()

    @tff.tf_computation
    def server_init_tf():
        """Builds the initial `ServerState` (fresh weights, round_num=0)."""
        model = model_fn()
        optimizer = server_optimizer_fn(model.weights.trainable)
        return ServerState(model_weights=model.weights,
                           optimizer_state=optimizer.init_state(),
                           round_num=0,
                           dp_clip_norm=dp_clip_norm)

    # Derive TFF types from the init computation so all later computations
    # agree on the server-state structure.
    server_state_type = server_init_tf.type_signature.result

    model_weights_type = server_state_type.model_weights

    @tff.tf_computation(server_state_type, model_weights_type.trainable)
    def server_update_fn(server_state, model_delta):
        """Applies one server optimizer step to the aggregated model delta."""
        model = model_fn()
        optimizer = server_optimizer_fn(model.weights.trainable)
        return server_update(model, optimizer, server_state, model_delta)

    @tff.tf_computation(server_state_type)
    def server_message_fn(server_state):
        """Extracts the broadcast message (weights etc.) from server state."""
        return build_server_broadcast_message(server_state)

    server_message_type = server_message_fn.type_signature.result
    tf_dataset_type = tff.SequenceType(example_model.input_spec)

    @tff.tf_computation(tf_dataset_type, server_message_type)
    def client_update_fn(tf_dataset, server_message):
        """Runs local training on one client's dataset."""
        model = model_fn()
        client_optimizer = client_optimizer_fn()
        return client_update(model, tf_dataset, server_message,
                             client_optimizer)

    federated_server_state_type = tff.type_at_server(server_state_type)
    federated_dataset_type = tff.type_at_clients(tf_dataset_type)

    @tff.federated_computation(federated_server_state_type,
                               federated_dataset_type)
    def run_one_round(server_state, federated_dataset):
        """Orchestration logic for one round of computation.

        Args:
          server_state: A `ServerState`.
          federated_dataset: A federated `tf.data.Dataset` with placement
            `tff.CLIENTS`.

        Returns:
          A tuple of updated `ServerState` and `tf.Tensor` of average loss.
        """
        server_message = tff.federated_map(server_message_fn, server_state)
        server_message_at_client = tff.federated_broadcast(server_message)

        client_outputs = tff.federated_map(
            client_update_fn, (federated_dataset, server_message_at_client))

        # Model deltas are equally weighted in DP.
        round_model_delta = tff.federated_mean(client_outputs.weights_delta)

        server_state = tff.federated_map(server_update_fn,
                                         (server_state, round_model_delta))
        round_loss_metric = tff.federated_mean(client_outputs.model_output)

        return server_state, round_loss_metric

    @tff.federated_computation
    def server_init_tff():
        """Orchestration logic for server model initialization."""
        return tff.federated_value(server_init_tf(), tff.SERVER)

    return tff.templates.IterativeProcess(initialize_fn=server_init_tff,
                                          next_fn=run_one_round)