Code Example #1
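A unit test for the attacked federated-averaging process from the TensorFlow Federated targeted-attack research code. It checks that the differentially private aggregator built by aggregate_fn.build_dp_aggregate clips the global L2 norm of the aggregated model delta to the configured bound. This revision still uses tff.learning.from_compiled_keras_model, an API that was later deprecated and removed from TFF.
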
    def test_dp_fed_mean(self):
        """Test whether the norm clipping is done successfully."""
        def model_fn():
            # `sample_batch` is captured by closure; it is assigned below,
            # before the trainer first invokes `model_fn`.
            return tff.learning.from_compiled_keras_model(
                tff.simulation.models.mnist.create_simple_keras_model(),
                sample_batch)

        client_data = create_client_data()
        batch = client_data()
        train_data = [batch]
        malicious_data = [batch]
        client_type_list = [tf.constant(False)]

        sample_batch = self.evaluate(next(iter(train_data[0])))
        l2_norm = 0.01
        # DP aggregator that clips each client delta to a global L2 norm of
        # `l2_norm`; noise multiplier factor 0.0, averaged over one client.
        dp_aggregate_fn = aggregate_fn.build_dp_aggregate(l2_norm, 0.0, 1.0)
        trainer = build_federated_averaging_process_attacked(
            model_fn, stateful_delta_aggregate_fn=dp_aggregate_fn)

        state = trainer.initialize()
        initial_weights = state.model.trainable

        state, _ = trainer.next(state, train_data, malicious_data,
                                client_type_list)

        weights_delta = tf.nest.map_structure(tf.subtract,
                                              state.model.trainable._asdict(),
                                              initial_weights._asdict())

        # The aggregated update must respect the clip, with 10% headroom for
        # numerical error.
        self.assertLess(attacked_fedavg._get_norm(weights_delta),
                        l2_norm * 1.1)
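
The clipping this test exercises can be reproduced with plain TensorFlow. A minimal sketch, assuming the aggregator clips by global L2 norm the way tf.clip_by_global_norm does; clip_delta_to_norm is a name introduced here for illustration, not part of the library:

import tensorflow as tf

def clip_delta_to_norm(delta, l2_norm):
    # Flatten the structure, clip by the global L2 norm, and repack.
    flat = tf.nest.flatten(delta)
    clipped, norm_before = tf.clip_by_global_norm(flat, l2_norm)
    return tf.nest.pack_sequence_as(delta, clipped), norm_before

# A delta whose global norm is 5.0 gets rescaled to the 0.01 bound.
delta = {'kernel': tf.constant([[3.0, 4.0]]), 'bias': tf.constant([0.0])}
clipped_delta, norm_before = clip_delta_to_norm(delta, 0.01)
print(norm_before.numpy())  # 5.0
print(tf.linalg.global_norm(tf.nest.flatten(clipped_delta)).numpy())  # ~0.01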
Code Example #2
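The same test from a later revision of the codebase: the model constructor has been hoisted to a module-level _model_fn helper, the sample_batch plumbing and the deprecated from_compiled_keras_model call are gone, and the weight deltas are diffed directly on the trainable structures without _asdict().
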
    def test_dp_fed_mean(self):
        """Test whether the norm clipping is done successfully."""
        client_data = create_client_data()
        batch = client_data()
        train_data = [batch]
        malicious_data = [batch]
        client_type_list = [tf.constant(False)]
        l2_norm = 0.01
        dp_aggregate_fn = aggregate_fn.build_dp_aggregate(l2_norm, 0.0, 1.0)
        # `_model_fn` is a module-level helper that builds the TFF model.
        trainer = build_federated_averaging_process_attacked(
            _model_fn, stateful_delta_aggregate_fn=dp_aggregate_fn)
        state = trainer.initialize()
        initial_weights = state.model.trainable
        state, _ = trainer.next(state, train_data, malicious_data,
                                client_type_list)
        weights_delta = tf.nest.map_structure(tf.subtract, state.model.trainable,
                                              initial_weights)
        self.assertLess(attacked_fedavg._get_norm(weights_delta), l2_norm * 1.1)
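
Both versions measure the update with attacked_fedavg._get_norm, whose implementation is not shown in this excerpt. Assuming it computes the global L2 norm over a structure of tensors, tf.linalg.global_norm yields the same quantity:

import tensorflow as tf

def get_norm(structure):
    # Assumed behavior of attacked_fedavg._get_norm: the square root of the
    # sum of squared entries across every tensor in the structure.
    return tf.linalg.global_norm(tf.nest.flatten(structure))

weights = {'kernel': tf.constant([[3.0, 4.0]]), 'bias': tf.constant([0.0])}
print(get_norm(weights).numpy())  # 5.0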
Code Example #3
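A full training driver built on the same components: it loads federated EMNIST, constructs the attacked federated-averaging process with a DP aggregator and a norm-bounded boosting client update (ClientProjectBoost), then runs the round loop with periodic attacks and evaluation. Helpers such as write_print, log_tfboard, load_test_data, load_malicious_dataset, create_keras_model, sample_clients_with_malicious, and evaluate, as well as the keys list, are defined elsewhere in the source file.
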
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    config = tf.compat.v1.ConfigProto()
    # 2 == RewriterConfig.OFF: turn off Grappler's layout optimizer.
    config.graph_options.rewrite_options.layout_optimizer = 2
    tf.compat.v1.enable_eager_execution(config)

    np.random.seed(FLAGS.random_seed)

    # `keys` is assumed to be a module-level list of hyperparameter flag
    # names; their values are baked into the log-file name.
    flag_dict = FLAGS.flag_values_dict()
    configs = '-'.join([
        '{}={}'.format(k, flag_dict[k]) for k in keys if k != 'root_output_dir'
    ])
    file_name = 'log' + configs
    file_handle = open(os.path.join(FLAGS.root_output_dir, file_name), 'w')

    global_step = tf.Variable(1, name='global_step', dtype=tf.int64)
    file_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.root_output_dir))
    file_writer.set_as_default()
    write_print(file_handle, '=======configurations========')
    write_print(file_handle, configs)
    write_print(file_handle, '=======configurations========')
    # prepare dataset.
    write_print(file_handle, 'Loading Dataset!')
    emnist_train, _ = tff.simulation.datasets.emnist.load_data(
        only_digits=FLAGS.only_digits)

    # prepare test set
    write_print(file_handle, 'Loading Test Set!')
    test_image, test_label = load_test_data()

    # load malicious dataset
    write_print(file_handle, 'Loading malicious dataset!')
    dataset_malicious, target_x, target_y = load_malicious_dataset(
        FLAGS.task_num)

    # prepare model_fn.
    example_dataset = emnist_train.create_tf_dataset_for_client(
        emnist_train.client_ids[0])
    input_spec = example_dataset.element_spec

    def model_fn():
        keras_model = create_keras_model()
        return tff.learning.from_keras_model(
            keras_model,
            input_spec=input_spec,
            loss=tf.keras.losses.SparseCategoricalCrossentropy(),
            metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

    # define server optimizer
    nesterov = FLAGS.server_momentum != 0

    def server_optimizer_fn():
        return tf.keras.optimizers.SGD(
            learning_rate=FLAGS.server_learning_rate,
            momentum=FLAGS.server_momentum,
            nesterov=nesterov)

    # build iterative process
    write_print(file_handle, 'Building Iterative Process!')
    client_update_function = attacked_fedavg.ClientProjectBoost(
        boost_factor=float(FLAGS.num_clients_per_round),
        norm_bound=FLAGS.norm_bound,
        round_num=FLAGS.client_round_num)
    aggregation_function = aggregate_fn.build_dp_aggregate(
        l2_norm=FLAGS.l2_norm_clip,
        mul_factor=FLAGS.mul_factor,
        num_clients=FLAGS.num_clients_per_round)

    iterative_process = attacked_fedavg.build_federated_averaging_process_attacked(
        model_fn=model_fn,
        stateful_delta_aggregate_fn=aggregation_function,
        client_update_tf=client_update_function,
        server_optimizer_fn=server_optimizer_fn)
    state = iterative_process.initialize()

    # training loop
    for cur_round in range(FLAGS.num_rounds):
        # The attack fires once per `attack_freq`-round cycle, at its midpoint.
        if cur_round % FLAGS.attack_freq == FLAGS.attack_freq // 2:
            with_attack = 1
            write_print(file_handle, 'Attacker appears!')
        else:
            with_attack = 0

        # sample clients and make federated dataset
        (federated_train_data, federated_malicious_data,
         client_type_list) = sample_clients_with_malicious(
             emnist_train, client_ids=emnist_train.client_ids,
             dataset_malicious=dataset_malicious,
             num_clients=FLAGS.num_clients_per_round, with_attack=with_attack)

        # one round of attacked federated averaging
        write_print(file_handle, 'Round starts!')
        state, train_metrics = iterative_process.next(
            state, federated_train_data, federated_malicious_data,
            client_type_list)

        write_print(
            file_handle, 'Training round {:2d}, train_metrics={}'.format(
                cur_round, train_metrics))

        log_tfboard('train_acc', train_metrics[0], global_step)
        log_tfboard('train_loss', train_metrics[1], global_step)

        # evaluate current model on test data and malicious data
        if cur_round % FLAGS.evaluate_per_rounds == 0:
            test_metrics, test_metrics_target = evaluate(
                state, test_image, test_label, target_x, target_y)
            write_print(
                file_handle,
                'Evaluation round {:2d}, <sparse_categorical_accuracy={},loss={}>'
                .format(cur_round, test_metrics[1], test_metrics[0]))
            # Metrics on the attacker's target examples.
            write_print(
                file_handle,
                'Evaluation round {:2d}, target task '
                '<sparse_categorical_accuracy={},loss={}>'.format(
                    cur_round, test_metrics_target[1],
                    test_metrics_target[0]))
            log_tfboard('test_acc', test_metrics[1], global_step)
            log_tfboard('test_loss', test_metrics[0], global_step)
            log_tfboard('test_acc_target', test_metrics_target[1], global_step)
            log_tfboard('test_loss_target', test_metrics_target[0],
                        global_step)

        global_step.assign_add(1)
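
write_print and log_tfboard do not appear in this excerpt. A minimal sketch of plausible implementations, consistent with how they are called above (an assumption, not necessarily the file's actual code):

import tensorflow as tf

def write_print(file_handle, line):
    # Hypothetical reconstruction: echo progress to stdout and append it to
    # the experiment log file.
    print(line)
    file_handle.write(line + '\n')

def log_tfboard(name, value, step):
    # Hypothetical reconstruction: write a scalar to the default summary
    # writer that main() installs.
    tf.summary.scalar(name, value, step=step)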