예제 #1
0
    def _assert_server_update_with_all_ones(self, model_fn):
        """Asserts that two server updates with an all-ones delta add 0.2.

        Builds a model with `model_fn`, initializes the server state with
        `server_init`, then calls `attacked_fedavg.server_update` twice with a
        `weights_delta` of all ones and an SGD server optimizer (learning rate
        0.1). Every trainable weight is expected to end at its initial value
        plus `learning_rate * num_rounds` (= 0.2).

        Args:
          model_fn: A no-argument callable returning the model under test. The
            returned model must expose the variables read by
            `attacked_fedavg._get_weights`.
        """
        learning_rate = 0.1
        num_rounds = 2
        model = model_fn()
        optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
        state, optimizer_vars = server_init(model, optimizer)
        trainable_weights = attacked_fedavg._get_weights(model).trainable
        weights_delta = tf.nest.map_structure(tf.ones_like, trainable_weights)
        # Compute the expectation from the starting values *before* updating,
        # so the check holds for any initializer (not only all-zeros) and is
        # independent of whether evaluated arrays alias variable storage.
        expected = tf.nest.map_structure(
            lambda v: v + learning_rate * num_rounds,
            self.evaluate(trainable_weights))

        for _ in range(num_rounds):
            state = attacked_fedavg.server_update(model, optimizer,
                                                  optimizer_vars, state,
                                                  weights_delta, ())

        train_vars = self.evaluate(state.model).trainable
        # Each round is expected to move every weight by `learning_rate`
        # along the all-ones delta.
        self.assertAllClose(train_vars, expected)
예제 #2
0
  def test_self_contained_example(self):
    """Two explicit-boosting client rounds should lower the reported loss."""
    make_dataset = create_client_data()
    mnist_model = MnistModel()
    sgd = tf.keras.optimizers.SGD(0.1)
    round_losses = []
    booster = attacked_fedavg.ClientExplicitBoosting(boost_factor=1.0)

    for _ in range(2):
      # Benign and malicious datasets, client-type flag, initial weights.
      result = booster(
          mnist_model,
          sgd,
          make_dataset(),
          make_dataset(),
          tf.constant(False),
          attacked_fedavg._get_weights(mnist_model),
      )
      round_losses.append(result.model_output['loss'].numpy())

    first_loss, second_loss = round_losses
    self.assertAllEqual(result.optimizer_output['num_examples'].numpy(), 2)
    self.assertLess(second_loss, first_loss)
예제 #3
0
def server_init(model, optimizer, delta_aggregate_state=()):
    """Builds the initial server state together with the optimizer variables.

    Args:
      model: The model whose weights seed the server state (e.g. a
        `tff.learning.Model`); its weights are read via
        `attacked_fedavg._get_weights`.
      optimizer: The server optimizer (e.g. a `tf.keras.optimizers.SGD`),
        passed to `attacked_fedavg._create_optimizer_vars`.
      delta_aggregate_state: Initial state of the delta aggregation process.
        Defaults to an empty tuple.

    Returns:
      A 2-tuple `(server_state, optimizer_vars)`:
        server_state: An `attacked_fedavg.ServerState` whose `model` is
          `attacked_fedavg._get_weights(model)`, whose `optimizer_state` is
          `optimizer_vars`, and whose `delta_aggregate_state` is the argument
          of the same name.
        optimizer_vars: The value returned by
          `attacked_fedavg._create_optimizer_vars(model, optimizer)`. It is
          returned separately so callers can pass it to `server_update`.
    """
    # Created first (same order as before) so any optimizer slot creation
    # happens before the model weights are captured.
    optimizer_vars = attacked_fedavg._create_optimizer_vars(model, optimizer)
    server_state = attacked_fedavg.ServerState(
        model=attacked_fedavg._get_weights(model),
        optimizer_state=optimizer_vars,
        delta_aggregate_state=delta_aggregate_state)
    return server_state, optimizer_vars