def test_input_types(self):
    """Checks the parameter types of the attacked FedAvg `next` computation.

    The process must be a `tff.templates.IterativeProcess`.
    Argument 1 (benign client data) must have the expected federated sequence
    type, and argument 2 (malicious client data) must share exactly that type.
    Argument 3 must be a boolean placed at CLIENTS.
    """
    process = build_federated_averaging_process_attacked(_model_fn)
    self.assertIsInstance(process, tff.templates.IterativeProcess)
    next_params = process.next.type_signature.parameter
    benign_data_type = next_params[1]
    self.assertEqual(str(benign_data_type),
                     '{<x=float32[?,784],y=int64[?,1]>*}@CLIENTS')
    # Benign and malicious client data are expected to be type-identical.
    self.assertEqual(str(next_params[1]), str(next_params[2]))
    client_flag_type = next_params[3]
    self.assertEqual(str(client_flag_type), '{bool}@CLIENTS')
 def test_self_contained_example_custom_model(self):
     """Loss should drop between two consecutive rounds using `MnistModel`.

     A single client, flagged `False`, is fed the same batch as both its
     benign and its malicious data.
     """
     make_batch = create_client_data()
     sample_batch = make_batch()
     benign_clients = [sample_batch]
     attacker_clients = [sample_batch]
     client_flags = [tf.constant(False)]
     trainer = build_federated_averaging_process_attacked(MnistModel)
     state = trainer.initialize()
     # Two rounds, unrolled: compare the loss reported by each.
     state, first_metrics = trainer.next(state, benign_clients,
                                         attacker_clients, client_flags)
     first_loss = first_metrics['loss']
     state, second_metrics = trainer.next(state, benign_clients,
                                          attacker_clients, client_flags)
     second_loss = second_metrics['loss']
     self.assertLess(second_loss, first_loss)
Example #3
0
  def test_attack(self):
      """Test whether an attacker is doing the right attack.

      A single client flagged `True` is trained with
      `ClientExplicitBoosting(boost_factor=-1.0)` for two rounds. The trainable
      weights are then expected to be close to their initial values.
      Currently skipped (b/150215351) because it became flaky.
      """
      self.skipTest('b/150215351 This test became flaky after TF change which '
                    'removed variable reads from control_outputs.')
      sample_batch = create_client_data()()
      benign_clients = [sample_batch]
      attacker_clients = [sample_batch]
      client_flags = [tf.constant(True)]
      booster = attacked_fedavg.ClientExplicitBoosting(boost_factor=-1.0)
      trainer = build_federated_averaging_process_attacked(
          _model_fn, client_update_tf=booster)
      state = trainer.initialize()
      start_weights = state.model.trainable
      for _round in range(2):
          state, _ = trainer.next(state, benign_clients, attacker_clients,
                                  client_flags)

      self.assertAllClose(start_weights, state.model.trainable)
Example #4
0
 def test_dp_fed_mean(self):
     """Test whether the norm clipping is done successfully.

     A DP aggregation process built from `GaussianAverageQuery(clip_norm,
     0.0, 1.0)` is plugged into the attacked FedAvg process. After one round,
     the norm of the change in trainable weights must be below
     `clip_norm * 1.1`.
     """
     sample_batch = create_client_data()()
     benign_clients = [sample_batch]
     attacker_clients = [sample_batch]
     client_flags = [tf.constant(False)]
     clip_norm = 0.01
     dp_query = tensorflow_privacy.GaussianAverageQuery(clip_norm, 0.0, 1.0)
     dp_factory = tff.aggregators.DifferentiallyPrivateFactory(dp_query)
     trainable_type = tff.learning.framework.weights_type_from_model(
         _model_fn).trainable
     trainer = build_federated_averaging_process_attacked(
         _model_fn, aggregation_process=dp_factory.create(trainable_type))
     state = trainer.initialize()
     weights_before = state.model.trainable
     state, _ = trainer.next(state, benign_clients, attacker_clients,
                             client_flags)
     update = tf.nest.map_structure(tf.subtract, state.model.trainable,
                                    weights_before)
     # Small slack above the clip norm to tolerate numerical error.
     self.assertLess(attacked_fedavg._get_norm(update), clip_norm * 1.1)
def main(argv):
    """Runs attacked federated averaging on EMNIST and logs the results.

    All settings come from absl `FLAGS`. The run configuration is written to
    a log file named after the flag values inside `FLAGS.root_output_dir`.
    `_run_training` then drives the training/evaluation loop, with that file
    as its text log.

    Args:
      argv: Positional command-line arguments left after flag parsing; only
        the program name is accepted.

    Raises:
      app.UsageError: If extra positional arguments are supplied.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    config = tf.compat.v1.ConfigProto()
    # NOTE(review): 2 presumably selects RewriterConfig.OFF (layout optimizer
    # disabled) — confirm against the RewriterConfig proto.
    config.graph_options.rewrite_options.layout_optimizer = 2
    tf.compat.v1.enable_eager_execution(config)

    np.random.seed(FLAGS.random_seed)

    # `keys` is a module-level collection of flag names defined outside this
    # function; their values tag the log file name.
    flag_dict = FLAGS.flag_values_dict()
    configs = '-'.join([
        '{}={}'.format(k, flag_dict[k]) for k in keys if k != 'root_output_dir'
    ])
    file_name = 'log' + configs
    create_if_not_exists(FLAGS.root_output_dir)
    # The context manager guarantees the log file is flushed and closed, even
    # if training raises part-way through.
    with open(os.path.join(FLAGS.root_output_dir, file_name), 'w') as file_handle:
        _run_training(file_handle, configs)


def _run_training(file_handle, configs):
    """Loads data, builds the attacked process, and runs every training round.

    Args:
      file_handle: Open, writable text file passed to `write_print`.
      configs: Run-configuration string echoed at the top of the log.
    """
    global_step = tf.Variable(1, name='global_step', dtype=tf.int64)
    file_writer = tf.summary.create_file_writer(FLAGS.root_output_dir)
    file_writer.set_as_default()
    write_print(file_handle, '=======configurations========')
    write_print(file_handle, configs)
    write_print(file_handle, '=======configurations========')

    # Prepare the federated training dataset.
    write_print(file_handle, 'Loading Dataset!')
    emnist_train, _ = tff.simulation.datasets.emnist.load_data(
        only_digits=FLAGS.only_digits)

    # Prepare the centralized test set.
    write_print(file_handle, 'Loading Test Set!')
    test_image, test_label = load_test_data()

    # Load the malicious dataset and the attacker's target examples.
    write_print(file_handle, 'Loading malicious dataset!')
    dataset_malicious, target_x, target_y = load_malicious_dataset(
        FLAGS.task_num)

    # The first client's preprocessed dataset fixes the model's input spec.
    example_dataset = preprocess(
        emnist_train.create_tf_dataset_for_client(emnist_train.client_ids[0]))

    write_print(file_handle, 'Building Iterative Process!')
    iterative_process = _build_attacked_process(example_dataset.element_spec)
    state = iterative_process.initialize()

    for cur_round in range(FLAGS.num_rounds):
        # One round per `attack_freq`-round period (offset attack_freq // 2)
        # includes the attacker.
        if cur_round % FLAGS.attack_freq == FLAGS.attack_freq // 2:
            with_attack = 1
            write_print(file_handle, 'Attacker appears!')
        else:
            with_attack = 0

        # Sample clients and build the federated (benign + malicious) data.
        federated_train_data, federated_malicious_data, client_type_list = (
            sample_clients_with_malicious(
                emnist_train, client_ids=emnist_train.client_ids,
                dataset_malicious=dataset_malicious,
                num_clients=FLAGS.num_clients_per_round,
                with_attack=with_attack))

        # One round of attacked federated averaging.
        write_print(file_handle, 'Round starts!')
        state, train_metrics = iterative_process.next(
            state, federated_train_data, federated_malicious_data,
            client_type_list)

        write_print(
            file_handle, 'Training round {:2d}, train_metrics={}'.format(
                cur_round, train_metrics))

        log_tfboard('train_acc', train_metrics['sparse_categorical_accuracy'],
                    global_step)
        log_tfboard('train_loss', train_metrics['loss'], global_step)

        # Periodically evaluate on the clean test set and the target data.
        if cur_round % FLAGS.evaluate_per_rounds == 0:
            test_metrics, test_metrics_target = evaluate(
                state, test_image, test_label, target_x, target_y)
            write_print(
                file_handle,
                'Evaluation round {:2d}, <sparse_categorical_accuracy={},loss={}>'
                .format(cur_round, test_metrics[1], test_metrics[0]))
            # Same message format, but these are the target-data metrics.
            write_print(
                file_handle,
                'Evaluation round {:2d}, <sparse_categorical_accuracy={},loss={}>'
                .format(cur_round, test_metrics_target[1],
                        test_metrics_target[0]))
            log_tfboard('test_acc', test_metrics[1], global_step)
            log_tfboard('test_loss', test_metrics[0], global_step)
            log_tfboard('test_acc_target', test_metrics_target[1], global_step)
            log_tfboard('test_loss_target', test_metrics_target[0],
                        global_step)

        global_step.assign_add(1)


def _build_attacked_process(input_spec):
    """Builds the attacked FedAvg iterative process configured from FLAGS.

    The process uses a `ClientProjectBoost` client update, a DP aggregation
    process built from a `GaussianAverageQuery`, and a server SGD optimizer.

    Args:
      input_spec: Element spec of a preprocessed client dataset, used as the
        Keras model's input spec.

    Returns:
      The result of
      `attacked_fedavg.build_federated_averaging_process_attacked`.
    """

    def model_fn():
        keras_model = create_keras_model()
        return tff.learning.from_keras_model(
            keras_model,
            input_spec=input_spec,
            loss=tf.keras.losses.SparseCategoricalCrossentropy(),
            metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

    # Nesterov momentum is only enabled when server momentum is non-zero.
    nesterov = FLAGS.server_momentum != 0

    def server_optimizer_fn():
        return tf.keras.optimizers.SGD(
            learning_rate=FLAGS.server_learning_rate,
            momentum=FLAGS.server_momentum,
            nesterov=nesterov)

    client_update_function = attacked_fedavg.ClientProjectBoost(
        boost_factor=float(FLAGS.num_clients_per_round),
        norm_bound=FLAGS.norm_bound,
        round_num=FLAGS.client_round_num)
    # NOTE(review): GaussianAverageQuery's second argument is an absolute noise
    # stddev; confirm that FLAGS.mul_factor is meant as that value rather than
    # a multiplier of l2_norm_clip.
    query = tensorflow_privacy.GaussianAverageQuery(
        FLAGS.l2_norm_clip, FLAGS.mul_factor, FLAGS.num_clients_per_round)
    dp_aggregate_fn = tff.utils.build_dp_aggregate_process(
        tff.learning.framework.weights_type_from_model(model_fn), query)
    return attacked_fedavg.build_federated_averaging_process_attacked(
        model_fn=model_fn,
        aggregation_process=dp_aggregate_fn,
        client_update_tf=client_update_function,
        server_optimizer_fn=server_optimizer_fn)