def test_attack(self):
  """Checks the explicit-boosting attack with a model built from a sample batch.

  A single client, flagged as malicious via ``tf.constant(True)``, trains with
  ``ClientExplicitBoosting(boost_factor=-1.0)`` for two federated rounds. The
  test then asserts that the server's trainable weights are numerically
  unchanged from their initial values (presumably because the -1.0 boost
  cancels the client's own delta -- see attacked_fedavg to confirm).
  """
  client_data = create_client_data()
  batch = client_data()
  client_datasets = [batch]
  malicious_datasets = [batch]
  client_is_malicious = [tf.constant(True)]
  # Concrete batch handed to from_compiled_keras_model inside model_fn; it is
  # assigned before model_fn is defined so the closure reads top-down.
  sample_batch = self.evaluate(next(iter(client_datasets[0])))

  def model_fn():
    keras_model = tff.simulation.models.mnist.create_simple_keras_model()
    return tff.learning.from_compiled_keras_model(keras_model, sample_batch)

  boosting_update = attacked_fedavg.ClientExplicitBoosting(boost_factor=-1.0)
  trainer = build_federated_averaging_process_attacked(
      model_fn, client_update_tf=boosting_update)

  state = trainer.initialize()
  initial_weights = state.model.trainable
  for _ in range(2):
    state, _ = trainer.next(state, client_datasets, malicious_datasets,
                            client_is_malicious)
  self.assertAllClose(initial_weights, state.model.trainable)
def test_self_contained_example(self):
  """Runs ClientExplicitBoosting(boost_factor=1.0) twice on a benign client.

  Each call passes ``tf.constant(False)`` as the client-type flag. The test
  checks that every call reports ``num_examples == 2`` in its optimizer
  output and that the loss reported by the second call is lower than the
  first.
  """
  client_data = create_client_data()
  keras_model = MnistModel()
  sgd = tf.keras.optimizers.SGD(0.1)
  recorded_losses = []
  update_fn = attacked_fedavg.ClientExplicitBoosting(boost_factor=1.0)

  for _ in range(2):
    # Fresh datasets are drawn for both the benign and malicious slots; the
    # current model weights are passed in as the reference weights.
    result = update_fn(keras_model, sgd, client_data(), client_data(),
                       tf.constant(False),
                       attacked_fedavg._get_weights(keras_model))
    recorded_losses.append(result.model_output['loss'].numpy())
    self.assertAllEqual(result.optimizer_output['num_examples'].numpy(), 2)

  first_loss, second_loss = recorded_losses
  self.assertLess(second_loss, first_loss)
def test_attack(self):
  """Checks the explicit-boosting attack using the shared ``_model_fn``.

  A single client, flagged as malicious via ``tf.constant(True)``, trains with
  ``ClientExplicitBoosting(boost_factor=-1.0)`` for two federated rounds. The
  test asserts the server's trainable weights end up numerically equal to
  their initial values.
  """
  client_data = create_client_data()
  one_batch = client_data()
  client_datasets = [one_batch]
  malicious_datasets = [one_batch]
  client_is_malicious = [tf.constant(True)]

  boosting_update = attacked_fedavg.ClientExplicitBoosting(boost_factor=-1.0)
  trainer = build_federated_averaging_process_attacked(
      _model_fn, client_update_tf=boosting_update)

  state = trainer.initialize()
  weights_before = state.model.trainable
  for _ in range(2):
    state, _ = trainer.next(state, client_datasets, malicious_datasets,
                            client_is_malicious)
  self.assertAllClose(weights_before, state.model.trainable)
def test_attack(self):
  """Checks the explicit-boosting attack; currently skipped (b/150215351).

  ``skipTest`` raises immediately, so the scenario below does not run while
  the skip is in place. When re-enabled, it has a single malicious client
  train with ``ClientExplicitBoosting(boost_factor=-1.0)`` for two rounds
  and asserts the trainable weights are unchanged from initialization.
  """
  self.skipTest('b/150215351 This test became flaky after TF change which '
                'removed variable reads from control_outputs.')

  client_data = create_client_data()
  one_batch = client_data()
  client_datasets = [one_batch]
  malicious_datasets = [one_batch]
  client_is_malicious = [tf.constant(True)]

  boosting_update = attacked_fedavg.ClientExplicitBoosting(boost_factor=-1.0)
  trainer = build_federated_averaging_process_attacked(
      _model_fn, client_update_tf=boosting_update)

  state = trainer.initialize()
  weights_before = state.model.trainable
  for _ in range(2):
    state, _ = trainer.next(state, client_datasets, malicious_datasets,
                            client_is_malicious)
  self.assertAllClose(weights_before, state.model.trainable)