def test_dp_fed_mean(self):
    """Verify that one round's weight delta stays within the clipping norm.

    Runs a single round of the attacked federated averaging process with one
    non-malicious client and the DP aggregate, then asserts that the norm of
    the change in trainable weights is below the clip norm plus 10% slack.
    """
    make_batch = create_client_data()
    shared_batch = make_batch()
    # The same batch is supplied as both the benign and the malicious dataset;
    # the single client is flagged as not malicious.
    benign_datasets = [shared_batch]
    attacker_datasets = [shared_batch]
    is_malicious = [tf.constant(False)]
    sample_batch = self.evaluate(next(iter(benign_datasets[0])))

    def model_fn():
        # Closes over `sample_batch`, which is materialized just above.
        return tff.learning.from_compiled_keras_model(
            tff.simulation.models.mnist.create_simple_keras_model(),
            sample_batch)

    clip_norm = 0.01
    clipping_aggregate = aggregate_fn.build_dp_aggregate(clip_norm, 0.0, 1.0)
    process = build_federated_averaging_process_attacked(
        model_fn, stateful_delta_aggregate_fn=clipping_aggregate)

    server_state = process.initialize()
    weights_before = server_state.model.trainable
    server_state, _ = process.next(
        server_state, benign_datasets, attacker_datasets, is_malicious)
    weights_after = server_state.model.trainable

    # Element-wise (after - before) over the named weight structure.
    update = tf.nest.map_structure(
        tf.subtract, weights_after._asdict(), weights_before._asdict())
    self.assertLess(attacked_fedavg._get_norm(update), clip_norm * 1.1)
def test_dp_fed_mean(self):
    """Verify that one round's weight delta stays within the clipping norm.

    Runs a single round of the attacked federated averaging process with one
    non-malicious client and the DP aggregate, then asserts that the norm of
    the change in trainable weights is below the clip norm plus 10% slack.
    """
    make_batch = create_client_data()
    shared_batch = make_batch()
    # The same batch is supplied as both the benign and the malicious dataset;
    # the single client is flagged as not malicious.
    benign_datasets = [shared_batch]
    attacker_datasets = [shared_batch]
    is_malicious = [tf.constant(False)]

    clip_norm = 0.01
    clipping_aggregate = aggregate_fn.build_dp_aggregate(clip_norm, 0.0, 1.0)
    process = build_federated_averaging_process_attacked(
        _model_fn, stateful_delta_aggregate_fn=clipping_aggregate)

    server_state = process.initialize()
    weights_before = server_state.model.trainable
    server_state, _ = process.next(
        server_state, benign_datasets, attacker_datasets, is_malicious)
    weights_after = server_state.model.trainable

    # Element-wise (after - before) over the trainable weight structure.
    update = tf.nest.map_structure(tf.subtract, weights_after, weights_before)
    self.assertLess(attacked_fedavg._get_norm(update), clip_norm * 1.1)
def main(argv):
    """Run attacked federated averaging on EMNIST and log the results.

    Enables eager execution, seeds NumPy from ``FLAGS.random_seed``, opens a
    text log under ``FLAGS.root_output_dir`` whose name encodes the flag
    values, and then runs ``FLAGS.num_rounds`` rounds of the attacked
    federated averaging process. An attacker is included in every round where
    ``cur_round % FLAGS.attack_freq == FLAGS.attack_freq // 2``. Every
    ``FLAGS.evaluate_per_rounds`` rounds the model is evaluated on the test
    set and on the targeted examples. Metrics go to the text log and to
    TensorBoard summaries.

    Args:
        argv: Positional command-line arguments; more than one entry is
            rejected.

    Raises:
        app.UsageError: If ``argv`` contains more than one element.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    config = tf.compat.v1.ConfigProto()
    config.graph_options.rewrite_options.layout_optimizer = 2
    tf.compat.v1.enable_eager_execution(config)

    np.random.seed(FLAGS.random_seed)

    flag_dict = FLAGS.flag_values_dict()
    # NOTE(review): `keys` must be a module-level iterable of flag names --
    # confirm it is defined before `main` runs.
    configs = '-'.join([
        '{}={}'.format(k, flag_dict[k]) for k in keys if k != 'root_output_dir'
    ])
    file_name = 'log' + configs

    # The context manager guarantees the log file is flushed and closed even
    # if a training round raises.
    with open(os.path.join(FLAGS.root_output_dir, file_name),
              'w') as file_handle:
        global_step = tf.Variable(1, name='global_step', dtype=tf.int64)
        file_writer = tf.summary.create_file_writer(
            os.path.join(FLAGS.root_output_dir))
        file_writer.set_as_default()
        write_print(file_handle, '=======configurations========')
        write_print(file_handle, configs)
        write_print(file_handle, '=======configurations========')

        # prepare dataset.
        write_print(file_handle, 'Loading Dataset!')
        emnist_train, _ = tff.simulation.datasets.emnist.load_data(
            only_digits=FLAGS.only_digits)

        # prepare test set
        write_print(file_handle, 'Loading Test Set!')
        test_image, test_label = load_test_data()

        # load malicious dataset
        write_print(file_handle, 'Loading malicious dataset!')
        dataset_malicious, target_x, target_y = load_malicious_dataset(
            FLAGS.task_num)

        # prepare model_fn.
        example_dataset = emnist_train.create_tf_dataset_for_client(
            emnist_train.client_ids[0])
        input_spec = example_dataset.element_spec

        def model_fn():
            keras_model = create_keras_model()
            return tff.learning.from_keras_model(
                keras_model,
                input_spec=input_spec,
                loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

        # define server optimizer; Nesterov is used only with non-zero momentum
        nesterov = FLAGS.server_momentum != 0

        def server_optimizer_fn():
            return tf.keras.optimizers.SGD(
                learning_rate=FLAGS.server_learning_rate,
                momentum=FLAGS.server_momentum,
                nesterov=nesterov)

        # build iterative process
        write_print(file_handle, 'Building Iterative Process!')
        client_update_function = attacked_fedavg.ClientProjectBoost(
            boost_factor=float(FLAGS.num_clients_per_round),
            norm_bound=FLAGS.norm_bound,
            round_num=FLAGS.client_round_num)
        aggregation_function = aggregate_fn.build_dp_aggregate(
            l2_norm=FLAGS.l2_norm_clip,
            mul_factor=FLAGS.mul_factor,
            num_clients=FLAGS.num_clients_per_round)

        iterative_process = (
            attacked_fedavg.build_federated_averaging_process_attacked(
                model_fn=model_fn,
                stateful_delta_aggregate_fn=aggregation_function,
                client_update_tf=client_update_function,
                server_optimizer_fn=server_optimizer_fn))
        state = iterative_process.initialize()

        # Both evaluation log lines share this exact format string.
        eval_fmt = ('Evaluation round {:2d}, '
                    '<sparse_categorical_accuracy={},loss={}>')

        # training loop
        for cur_round in range(FLAGS.num_rounds):
            if cur_round % FLAGS.attack_freq == FLAGS.attack_freq // 2:
                with_attack = 1
                write_print(file_handle, 'Attacker appears!')
            else:
                with_attack = 0

            # sample clients and make federated dataset
            (federated_train_data, federated_malicious_data,
             client_type_list) = sample_clients_with_malicious(
                 emnist_train,
                 client_ids=emnist_train.client_ids,
                 dataset_malicious=dataset_malicious,
                 num_clients=FLAGS.num_clients_per_round,
                 with_attack=with_attack)

            # one round of attacked federated averaging
            write_print(file_handle, 'Round starts!')
            state, train_metrics = iterative_process.next(
                state, federated_train_data, federated_malicious_data,
                client_type_list)

            write_print(
                file_handle,
                'Training round {:2d}, train_metrics={}'.format(
                    cur_round, train_metrics))

            log_tfboard('train_acc', train_metrics[0], global_step)
            log_tfboard('train_loss', train_metrics[1], global_step)

            # evaluate current model on test data and malicious data
            if cur_round % FLAGS.evaluate_per_rounds == 0:
                test_metrics, test_metrics_target = evaluate(
                    state, test_image, test_label, target_x, target_y)
                write_print(
                    file_handle,
                    eval_fmt.format(cur_round, test_metrics[1],
                                    test_metrics[0]))
                write_print(
                    file_handle,
                    eval_fmt.format(cur_round, test_metrics_target[1],
                                    test_metrics_target[0]))
                log_tfboard('test_acc', test_metrics[1], global_step)
                log_tfboard('test_loss', test_metrics[0], global_step)
                log_tfboard('test_acc_target', test_metrics_target[1],
                            global_step)
                log_tfboard('test_loss_target', test_metrics_target[0],
                            global_step)

            # Advance the summary step once per round; the train scalars above
            # are written every round against this step.
            global_step.assign_add(1)