Пример #1
0
    def test_attack(self):
        """Test whether an attacker is doing the right attack.

        A single client (client type `True`) runs the attacked federated
        averaging process with `ClientExplicitBoosting(boost_factor=-1.0)`
        for two rounds; the final trainable weights are asserted to be close
        to the initial ones.
        """
        client_data = create_client_data()
        batch = client_data()
        # The same dataset is fed as both the regular and the malicious input.
        train_data = [batch]
        malicious_data = [batch]
        # NOTE(review): presumably `True` marks this client as the attacker;
        # confirm against build_federated_averaging_process_attacked.
        client_type_list = [tf.constant(True)]
        # Bind `sample_batch` BEFORE defining `model_fn`: previously the
        # closure referenced a name assigned only later in this method, which
        # worked solely because of late binding and the call order.
        sample_batch = self.evaluate(next(iter(train_data[0])))

        def model_fn():
            # Wraps the simple MNIST Keras model as a TFF model, passing
            # `sample_batch` as the sample input.
            return tff.learning.from_compiled_keras_model(
                tff.simulation.models.mnist.create_simple_keras_model(),
                sample_batch)

        trainer = build_federated_averaging_process_attacked(
            model_fn,
            client_update_tf=attacked_fedavg.ClientExplicitBoosting(
                boost_factor=-1.0))

        state = trainer.initialize()
        initial_weights = state.model.trainable
        for _ in range(2):
            state, _ = trainer.next(state, train_data, malicious_data,
                                    client_type_list)

        self.assertAllClose(initial_weights, state.model.trainable)
Пример #2
0
  def test_self_contained_example(self):
    """Run `ClientExplicitBoosting(boost_factor=1.0)` for two rounds.

    Asserts that the last round reports 2 in `num_examples` of its optimizer
    output, and that the second round's loss is lower than the first's.
    """
    client_data = create_client_data()
    model = MnistModel()
    optimizer = tf.keras.optimizers.SGD(0.1)
    round_losses = []
    boosted_update = attacked_fedavg.ClientExplicitBoosting(boost_factor=1.0)
    for _ in range(2):
      # Two freshly built datasets per round; NOTE(review): presumably the
      # regular and malicious inputs respectively — confirm against the
      # ClientExplicitBoosting call signature.
      train_ds = client_data()
      malicious_ds = client_data()
      client_type = tf.constant(False)
      current_weights = attacked_fedavg._get_weights(model)
      outputs = boosted_update(model, optimizer, train_ds, malicious_ds,
                               client_type, current_weights)
      round_losses.append(outputs.model_output['loss'].numpy())

    first_loss, second_loss = round_losses
    self.assertAllEqual(outputs.optimizer_output['num_examples'].numpy(), 2)
    self.assertLess(second_loss, first_loss)
    def test_attack(self):
        """Check that a boost factor of -1 leaves the trainable weights unchanged.

        Runs two rounds of the attacked federated averaging process (built
        from `_model_fn`) with one client of type `True` and asserts that the
        final trainable weights are close to the initial ones.
        """
        make_dataset = create_client_data()
        shared_batch = make_dataset()
        # The same dataset serves as both the regular and the malicious input.
        train_data, malicious_data = [shared_batch], [shared_batch]
        client_type_list = [tf.constant(True)]
        negative_boost = attacked_fedavg.ClientExplicitBoosting(
            boost_factor=-1.0)
        trainer = build_federated_averaging_process_attacked(
            _model_fn, client_update_tf=negative_boost)

        state = trainer.initialize()
        weights_before = state.model.trainable
        num_rounds = 2
        for _ in range(num_rounds):
            state, _ = trainer.next(
                state, train_data, malicious_data, client_type_list)

        self.assertAllClose(weights_before, state.model.trainable)
Пример #4
0
  def test_attack(self):
    """Check that a boost factor of -1 leaves the trainable weights unchanged.

    Currently skipped (b/150215351): `skipTest` raises immediately, so the
    code below it does not run until the skip is removed.
    """
    self.skipTest('b/150215351 This test became flaky after TF change which '
                  'removed variable reads from control_outputs.')
    make_dataset = create_client_data()
    shared_batch = make_dataset()
    # The same dataset serves as both the regular and the malicious input.
    train_data, malicious_data = [shared_batch], [shared_batch]
    client_type_list = [tf.constant(True)]
    negative_boost = attacked_fedavg.ClientExplicitBoosting(boost_factor=-1.0)
    trainer = build_federated_averaging_process_attacked(
        _model_fn, client_update_tf=negative_boost)

    state = trainer.initialize()
    weights_before = state.model.trainable
    num_rounds = 2
    for _ in range(num_rounds):
      state, _ = trainer.next(
          state, train_data, malicious_data, client_type_list)

    self.assertAllClose(weights_before, state.model.trainable)