Exemplo n.º 1
0
  def test_self_contained_example(self):
    """Runs two ClientExplicitBoosting updates and checks training progress.

    The update uses boost_factor=1.0 and is called with client_type
    tf.constant(False). Presumably that marks the client as benign; TODO
    confirm against attacked_fedavg.

    Each round is given freshly created datasets and the model's current
    weights. The test asserts two things:
      * the loss reported by the second round is below that of the first;
      * the last round reports num_examples == 2.
    """
    dataset_fn = create_client_data()
    mnist_model = MnistModel()
    sgd = tf.keras.optimizers.SGD(0.1)
    update_fn = attacked_fedavg.ClientExplicitBoosting(boost_factor=1.0)
    round_losses = []
    for _round in range(2):
      # Evaluated in the same order as the positional arguments below.
      benign_ds = dataset_fn()
      malicious_ds = dataset_fn()
      is_malicious = tf.constant(False)
      start_weights = attacked_fedavg._get_weights(mnist_model)
      outputs = update_fn(mnist_model, sgd, benign_ds, malicious_ds,
                          is_malicious, start_weights)
      round_loss = outputs.model_output['loss'].numpy()
      round_losses.append(round_loss)

    # `outputs` holds the result of the final round here.
    self.assertAllEqual(outputs.optimizer_output['num_examples'].numpy(), 2)
    first_loss, second_loss = round_losses
    self.assertLess(second_loss, first_loss)
Exemplo n.º 2
0
  def test_attack(self):
    """Test whether an attacker is doing the right attack.

    Setup:
      * a single client, flagged malicious via tf.constant(True);
      * the same batch is passed as both its benign and its malicious data;
      * the client update is ClientExplicitBoosting with boost_factor=-1.0.

    After two rounds of the attacked federated-averaging process, the test
    asserts that the trainable weights are close to their initial values.

    The test is currently skipped unconditionally (see the skip message).
    """
    self.skipTest('b/150215351 This test became flaky after TF change which '
                  'removed variable reads from control_outputs.')
    dataset_fn = create_client_data()
    shared_batch = dataset_fn()
    # One client: identical benign and malicious inputs.
    benign_datasets = [shared_batch]
    attack_datasets = [shared_batch]
    client_is_malicious = [tf.constant(True)]
    negative_boost_update = attacked_fedavg.ClientExplicitBoosting(
        boost_factor=-1.0)
    process = build_federated_averaging_process_attacked(
        _model_fn, client_update_tf=negative_boost_update)
    state = process.initialize()
    weights_before = state.model.trainable
    for _round in range(2):
      state, _metrics = process.next(state, benign_datasets, attack_datasets,
                                     client_is_malicious)

    self.assertAllClose(weights_before, state.model.trainable)