def test_input_types(self):
  """Checks the parameter types accepted by the attacked process's `next`.

  Parameters 1 and 2 of `next` must share the same federated dataset type,
  and parameter 3 must be a CLIENTS-placed boolean.
  """
  process = build_federated_averaging_process_attacked(_model_fn)
  self.assertIsInstance(process, tff.templates.IterativeProcess)

  # Look the signature up once; every check below indexes into it.
  next_params = process.next.type_signature.parameter

  data_type_repr = str(next_params[1])
  self.assertEqual(data_type_repr,
                   '{<x=float32[?,784],y=int64[?,1]>*}@CLIENTS')
  self.assertEqual(data_type_repr, str(next_params[2]))

  self.assertEqual(str(next_params[3]), '{bool}@CLIENTS')
def test_self_contained_example_custom_model(self):
  """Runs two rounds with one non-malicious client; the loss must decrease."""
  single_batch = create_client_data()()
  benign_data = [single_batch]
  attack_data = [single_batch]
  # One client, flagged as not malicious.
  is_malicious = [tf.constant(False)]

  process = build_federated_averaging_process_attacked(MnistModel)
  server_state = process.initialize()

  round_losses = []
  while len(round_losses) < 2:
    server_state, metrics = process.next(server_state, benign_data,
                                         attack_data, is_malicious)
    round_losses.append(metrics['loss'])

  first_loss, second_loss = round_losses
  self.assertLess(second_loss, first_loss)
def test_attack(self):
  """Checks that a malicious client with boost factor -1.0 undoes training.

  Currently skipped unconditionally (see the bug referenced in the skip
  message); the statements after `skipTest` are kept for when it is
  re-enabled.
  """
  self.skipTest('b/150215351 This test became flaky after TF change which '
                'removed variable reads from control_outputs.')
  single_batch = create_client_data()()
  benign_data = [single_batch]
  attack_data = [single_batch]
  # One client, flagged as malicious.
  is_malicious = [tf.constant(True)]

  boosting_update = attacked_fedavg.ClientExplicitBoosting(boost_factor=-1.0)
  process = build_federated_averaging_process_attacked(
      _model_fn, client_update_tf=boosting_update)

  server_state = process.initialize()
  starting_weights = server_state.model.trainable
  for _ in range(2):
    round_result = process.next(server_state, benign_data, attack_data,
                                is_malicious)
    server_state = round_result[0]

  # The trainable weights are expected to match their initial values.
  self.assertAllClose(starting_weights, server_state.model.trainable)
def test_dp_fed_mean(self):
  """Checks that one round's weight update stays within the clipping norm.

  A `GaussianAverageQuery` built with clip 0.01 is used for aggregation; the
  norm of the resulting weight delta must stay below 1.1 times that clip.
  """
  single_batch = create_client_data()()
  benign_data = [single_batch]
  attack_data = [single_batch]
  is_malicious = [tf.constant(False)]

  clip_norm = 0.01
  dp_factory = tff.aggregators.DifferentiallyPrivateFactory(
      tensorflow_privacy.GaussianAverageQuery(clip_norm, 0.0, 1.0))
  trainable_type = tff.learning.framework.weights_type_from_model(
      _model_fn).trainable
  dp_process = dp_factory.create(trainable_type)

  process = build_federated_averaging_process_attacked(
      _model_fn, aggregation_process=dp_process)
  server_state = process.initialize()
  weights_before = server_state.model.trainable

  server_state, _ = process.next(server_state, benign_data, attack_data,
                                 is_malicious)
  weights_after = server_state.model.trainable

  delta = tf.nest.map_structure(tf.subtract, weights_after, weights_before)
  delta_norm = attacked_fedavg._get_norm(delta)
  self.assertLess(delta_norm, clip_norm * 1.1)
def main(argv):
  """Trains on EMNIST with attacked federated averaging and logs metrics.

  Builds a log file named after the flag configuration under
  `FLAGS.root_output_dir`, loads the EMNIST training data, the test set and
  the malicious dataset, constructs the attacked federated averaging process
  with a differentially private aggregation process, and then runs
  `FLAGS.num_rounds` rounds. An attacker is injected on rounds where
  `cur_round % FLAGS.attack_freq == FLAGS.attack_freq // 2`. Training metrics
  are logged every round; test and target-task metrics are logged every
  `FLAGS.evaluate_per_rounds` rounds.

  Args:
    argv: Command-line argument list; more than one entry is rejected.

  Raises:
    app.UsageError: If `argv` contains more than one element.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  config = tf.compat.v1.ConfigProto()
  # NOTE(review): 2 presumably corresponds to RewriterConfig.OFF -- confirm.
  config.graph_options.rewrite_options.layout_optimizer = 2
  tf.compat.v1.enable_eager_execution(config)

  np.random.seed(FLAGS.random_seed)

  flag_dict = FLAGS.flag_values_dict()
  configs = '-'.join([
      '{}={}'.format(k, flag_dict[k]) for k in keys if k != 'root_output_dir'
  ])
  file_name = 'log' + configs
  create_if_not_exists(FLAGS.root_output_dir)

  # The context manager guarantees the log file is flushed and closed even
  # when a training round raises; previously the handle was never closed.
  with open(os.path.join(FLAGS.root_output_dir, file_name),
            'w') as file_handle:
    global_step = tf.Variable(1, name='global_step', dtype=tf.int64)
    file_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.root_output_dir))
    file_writer.set_as_default()
    write_print(file_handle, '=======configurations========')
    write_print(file_handle, configs)
    write_print(file_handle, '=======configurations========')

    # prepare dataset.
    write_print(file_handle, 'Loading Dataset!')
    emnist_train, _ = tff.simulation.datasets.emnist.load_data(
        only_digits=FLAGS.only_digits)

    # prepare test set
    write_print(file_handle, 'Loading Test Set!')
    test_image, test_label = load_test_data()

    # load malicious dataset
    write_print(file_handle, 'Loading malicious dataset!')
    dataset_malicious, target_x, target_y = load_malicious_dataset(
        FLAGS.task_num)

    # prepare model_fn.
    example_dataset = preprocess(
        emnist_train.create_tf_dataset_for_client(emnist_train.client_ids[0]))
    input_spec = example_dataset.element_spec

    def model_fn():
      """Returns a fresh TFF model wrapping a newly created Keras model."""
      keras_model = create_keras_model()
      return tff.learning.from_keras_model(
          keras_model,
          input_spec=input_spec,
          loss=tf.keras.losses.SparseCategoricalCrossentropy(),
          metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

    # define server optimizer; Nesterov is enabled only with non-zero momentum.
    nesterov = FLAGS.server_momentum != 0

    def server_optimizer_fn():
      """Returns the SGD optimizer applied on the server."""
      return tf.keras.optimizers.SGD(
          learning_rate=FLAGS.server_learning_rate,
          momentum=FLAGS.server_momentum,
          nesterov=nesterov)

    # build iterative process
    write_print(file_handle, 'Building Iterative Process!')
    client_update_function = attacked_fedavg.ClientProjectBoost(
        boost_factor=float(FLAGS.num_clients_per_round),
        norm_bound=FLAGS.norm_bound,
        round_num=FLAGS.client_round_num)
    query = tensorflow_privacy.GaussianAverageQuery(
        FLAGS.l2_norm_clip, FLAGS.mul_factor, FLAGS.num_clients_per_round)
    dp_aggregate_fn = tff.utils.build_dp_aggregate_process(
        tff.learning.framework.weights_type_from_model(model_fn), query)
    iterative_process = (
        attacked_fedavg.build_federated_averaging_process_attacked(
            model_fn=model_fn,
            aggregation_process=dp_aggregate_fn,
            client_update_tf=client_update_function,
            server_optimizer_fn=server_optimizer_fn))
    state = iterative_process.initialize()

    # training loop
    for cur_round in range(FLAGS.num_rounds):
      # The attacker shows up once per `attack_freq` rounds, at the midpoint.
      if cur_round % FLAGS.attack_freq == FLAGS.attack_freq // 2:
        with_attack = 1
        write_print(file_handle, 'Attacker appears!')
      else:
        with_attack = 0

      # sample clients and make federated dataset
      federated_train_data, federated_malicious_data, client_type_list = (
          sample_clients_with_malicious(
              emnist_train,
              client_ids=emnist_train.client_ids,
              dataset_malicious=dataset_malicious,
              num_clients=FLAGS.num_clients_per_round,
              with_attack=with_attack))

      # one round of attacked federated averaging
      write_print(file_handle, 'Round starts!')
      state, train_metrics = iterative_process.next(
          state, federated_train_data, federated_malicious_data,
          client_type_list)

      write_print(
          file_handle,
          'Training round {:2d}, train_metrics={}'.format(
              cur_round, train_metrics))

      log_tfboard('train_acc', train_metrics['sparse_categorical_accuracy'],
                  global_step)
      log_tfboard('train_loss', train_metrics['loss'], global_step)

      # evaluate current model on test data and malicious data
      if cur_round % FLAGS.evaluate_per_rounds == 0:
        test_metrics, test_metrics_target = evaluate(
            state, test_image, test_label, target_x, target_y)
        # Index 0 is logged as the loss and index 1 as the accuracy.
        write_print(
            file_handle,
            'Evaluation round {:2d}, <sparse_categorical_accuracy={},loss={}>'
            .format(cur_round, test_metrics[1], test_metrics[0]))
        # Same message format, this time for the target-task metrics.
        write_print(
            file_handle,
            'Evaluation round {:2d}, <sparse_categorical_accuracy={},loss={}>'
            .format(cur_round, test_metrics_target[1], test_metrics_target[0]))
        log_tfboard('test_acc', test_metrics[1], global_step)
        log_tfboard('test_loss', test_metrics[0], global_step)
        log_tfboard('test_acc_target', test_metrics_target[1], global_step)
        log_tfboard('test_loss_target', test_metrics_target[0], global_step)

      global_step.assign_add(1)