# --- Example 1 ---
    def test_aggregate_and_clip(self):
        """Verify the clipped aggregate keeps the round's update norm bounded."""
        # One client batch, reused for both the benign and malicious feeds.
        one_batch = create_client_data()()
        train_feed = [one_batch]
        malicious_feed = [one_batch]
        client_flags = [tf.constant(False)]
        sample_batch = self.evaluate(next(iter(train_feed[0])))

        def model_fn():
            return tff.learning.from_compiled_keras_model(
                tff.simulation.models.mnist.create_simple_keras_model(),
                sample_batch)

        norm_bound = 0.01
        clipping_aggregate = aggregate_fn.build_aggregate_and_clip(
            norm_bound=norm_bound)
        trainer = build_federated_averaging_process_attacked(
            model_fn, stateful_delta_aggregate_fn=clipping_aggregate)

        state = trainer.initialize()
        weights_before = state.model.trainable
        state, _ = trainer.next(state, train_feed, malicious_feed,
                                client_flags)

        # Per-variable difference between post- and pre-round weights.
        delta = tf.nest.map_structure(tf.subtract,
                                      state.model.trainable._asdict(),
                                      weights_before._asdict())
        # Allow 1% slack on top of the configured bound.
        self.assertLess(attacked_fedavg._get_norm(delta), norm_bound * 1.01)
# --- Example 2 ---
    def test_attack(self):
        """An explicit -1.0 boosting attacker cancels the honest update."""
        one_batch = create_client_data()()
        train_feed = [one_batch]
        malicious_feed = [one_batch]
        # The single client is marked malicious.
        client_flags = [tf.constant(True)]
        sample_batch = self.evaluate(next(iter(train_feed[0])))

        def model_fn():
            return tff.learning.from_compiled_keras_model(
                tff.simulation.models.mnist.create_simple_keras_model(),
                sample_batch)

        trainer = build_federated_averaging_process_attacked(
            model_fn,
            client_update_tf=attacked_fedavg.ClientExplicitBoosting(
                boost_factor=-1.0))
        state = trainer.initialize()
        weights_before = state.model.trainable
        for _ in range(2):
            state, _ = trainer.next(state, train_feed, malicious_feed,
                                    client_flags)
        # The -1.0-boosted delta negates training, so weights stay put.
        self.assertAllClose(weights_before, state.model.trainable)
# --- Example 3 ---
 def test_input_types(self):
     """Check the `next` type signature of the attacked iterative process."""
     process = build_federated_averaging_process_attacked(_model_fn)
     self.assertIsInstance(process, tff.templates.IterativeProcess)
     next_params = process.next.type_signature.parameter
     # Parameter 1: benign federated dataset placed at clients.
     self.assertEqual(str(next_params[1]),
                      '{<x=float32[?,784],y=int64[?,1]>*}@CLIENTS')
     # Parameter 2: malicious dataset, same type as the benign one.
     self.assertEqual(str(next_params[1]), str(next_params[2]))
     # Parameter 3: per-client malicious/benign flag.
     self.assertEqual(str(next_params[3]), '{bool}@CLIENTS')
# --- Example 4 ---
 def test_self_contained_example_custom_model(self):
   """Two benign training rounds with the custom model reduce the loss."""
   one_batch = create_client_data()()
   train_feed = [one_batch]
   malicious_feed = [one_batch]
   client_flags = [tf.constant(False)]
   trainer = build_federated_averaging_process_attacked(MnistModel)
   state = trainer.initialize()
   round_losses = []
   for _ in range(2):
     state, outputs = trainer.next(state, train_feed, malicious_feed,
                                   client_flags)
     round_losses.append(outputs.loss)
   self.assertLess(round_losses[1], round_losses[0])
# --- Example 5 ---
 def test_aggregate_and_clip(self):
   """Verify the clipped aggregate keeps the round's update norm bounded."""
   one_batch = create_client_data()()
   train_feed = [one_batch]
   malicious_feed = [one_batch]
   client_flags = [tf.constant(False)]
   norm_bound = 0.01
   clip_aggregate = aggregate_fn.build_aggregate_and_clip(
       norm_bound=norm_bound)
   trainer = build_federated_averaging_process_attacked(
       _model_fn, stateful_delta_aggregate_fn=clip_aggregate)
   state = trainer.initialize()
   weights_before = state.model.trainable
   state, _ = trainer.next(state, train_feed, malicious_feed, client_flags)
   # Per-variable difference between post- and pre-round weights.
   delta = tf.nest.map_structure(tf.subtract, state.model.trainable,
                                 weights_before)
   # Allow 1% slack on top of the configured bound.
   self.assertLess(attacked_fedavg._get_norm(delta), norm_bound * 1.01)
    def test_attack(self):
        """An explicit -1.0 boosting attacker cancels the honest update."""
        one_batch = create_client_data()()
        train_feed = [one_batch]
        malicious_feed = [one_batch]
        # The single client is marked malicious.
        client_flags = [tf.constant(True)]
        trainer = build_federated_averaging_process_attacked(
            _model_fn,
            client_update_tf=attacked_fedavg.ClientExplicitBoosting(
                boost_factor=-1.0))
        state = trainer.initialize()
        weights_before = state.model.trainable
        for _ in range(2):
            state, _ = trainer.next(state, train_feed, malicious_feed,
                                    client_flags)
        # The -1.0-boosted delta negates training, so weights stay put.
        self.assertAllClose(weights_before, state.model.trainable)
# --- Example 7 ---
  def test_attack(self):
    """An explicit -1.0 boosting attacker cancels the honest update."""
    # Currently disabled; see the referenced bug for context.
    self.skipTest('b/150215351 This test became flaky after TF change which '
                  'removed variable reads from control_outputs.')
    one_batch = create_client_data()()
    train_feed = [one_batch]
    malicious_feed = [one_batch]
    # The single client is marked malicious.
    client_flags = [tf.constant(True)]
    trainer = build_federated_averaging_process_attacked(
        _model_fn,
        client_update_tf=attacked_fedavg.ClientExplicitBoosting(
            boost_factor=-1.0))
    state = trainer.initialize()
    weights_before = state.model.trainable
    for _ in range(2):
      state, _ = trainer.next(state, train_feed, malicious_feed, client_flags)
    # The -1.0-boosted delta negates training, so weights stay put.
    self.assertAllClose(weights_before, state.model.trainable)
# --- Example 8 ---
 def test_dp_fed_mean(self):
     """DP-averaged aggregation keeps the round's weight delta bounded."""
     one_batch = create_client_data()()
     train_feed = [one_batch]
     malicious_feed = [one_batch]
     client_flags = [tf.constant(False)]
     norm_bound = 0.01
     # Zero noise, denominator 1: clipping only, no randomness.
     gaussian_query = tensorflow_privacy.GaussianAverageQuery(
         norm_bound, 0.0, 1.0)
     dp_process = tff.utils.build_dp_aggregate_process(
         tff.learning.framework.weights_type_from_model(
             _model_fn).trainable, gaussian_query)
     trainer = build_federated_averaging_process_attacked(
         _model_fn, aggregation_process=dp_process)
     state = trainer.initialize()
     weights_before = state.model.trainable
     state, _ = trainer.next(state, train_feed, malicious_feed, client_flags)
     # Per-variable difference between post- and pre-round weights.
     delta = tf.nest.map_structure(tf.subtract, state.model.trainable,
                                   weights_before)
     # Allow 10% slack on top of the clipping bound.
     self.assertLess(attacked_fedavg._get_norm(delta), norm_bound * 1.1)
# --- Example 9 ---
    def test_self_contained_example_keras_model(self):
        """Two benign training rounds with a Keras model reduce the loss."""
        one_batch = create_client_data()()
        train_feed = [one_batch]
        malicious_feed = [one_batch]
        client_flags = [tf.constant(False)]
        sample_batch = self.evaluate(next(iter(train_feed[0])))

        def model_fn():
            return tff.learning.from_compiled_keras_model(
                tff.simulation.models.mnist.create_simple_keras_model(),
                sample_batch)

        trainer = build_federated_averaging_process_attacked(model_fn)
        state = trainer.initialize()
        round_losses = []
        for _ in range(2):
            state, outputs = trainer.next(state, train_feed, malicious_feed,
                                          client_flags)
            # Track the loss of each round.
            round_losses.append(outputs.loss)
        self.assertLess(round_losses[1], round_losses[0])
# --- Example 10 ---
 def test_malicious_nodes(self):
   """Norm clipping bounds the update even with a malicious client present.

   Samples real EMNIST clients (with the attack enabled) and checks that the
   aggregated model delta stays within the configured L2 norm bound.
   """
   norm_bound = 0.01
   aggregate_clip = aggregate_fn.build_aggregate_and_clip(
       norm_bound=norm_bound)
   trainer = build_federated_averaging_process_attacked(
       _model_fn, stateful_delta_aggregate_fn=aggregate_clip)
   state = trainer.initialize()
   initial_weights = state.model.trainable

   emnist_train, _ = tff.simulation.datasets.emnist.load_data(
       only_digits=True)
   # Only the malicious dataset itself is needed here; the attack-target
   # tensors returned alongside it are unused in this test.
   dataset_malicious, _, _ = em.load_malicious_dataset(30)

   federated_train_data, federated_malicious_data, client_type_list = (
       em.sample_clients_with_malicious(
           emnist_train, client_ids=emnist_train.client_ids,
           dataset_malicious=dataset_malicious,
           num_clients=5, with_attack=1))

   state, _ = trainer.next(state, federated_train_data,
                           federated_malicious_data, client_type_list)
   weights_delta = tf.nest.map_structure(tf.subtract,
                                         state.model.trainable._asdict(),
                                         initial_weights._asdict())
   # Allow 1% slack on top of the configured bound.
   self.assertLess(attacked_fedavg._get_norm(weights_delta),
                   norm_bound * 1.01)
def main(argv):
  """Run attacked federated averaging on EMNIST with a DP aggregate.

  Sets up eager execution, logging (a text log file plus TensorBoard
  summaries), the EMNIST training data, a test set, and a malicious dataset,
  then runs `FLAGS.num_rounds` rounds of the attacked federated averaging
  process, periodically attacking and evaluating.

  Args:
    argv: Command-line arguments; only the program name is accepted.

  Raises:
    app.UsageError: If extra command-line arguments are supplied.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  config = tf.compat.v1.ConfigProto()
  # NOTE(review): 2 is presumably RewriterConfig.OFF (disable the layout
  # optimizer) — confirm against the RewriterConfig enum.
  config.graph_options.rewrite_options.layout_optimizer = 2
  tf.compat.v1.enable_eager_execution(config)

  np.random.seed(FLAGS.random_seed)

  # Build a "k=v-k=v-..." configuration string used both as the log file
  # name suffix and as a header inside the log. `keys` is a module-level
  # collection of flag names (not visible in this chunk).
  flag_dict = FLAGS.flag_values_dict()
  configs = '-'.join(
      ['{}={}'.format(k, flag_dict[k]) for k in keys if k != 'root_output_dir'])
  file_name = 'log' + configs
  create_if_not_exists(FLAGS.root_output_dir)
  file_handle = open(os.path.join(FLAGS.root_output_dir, file_name), 'w')

  # Global step drives the TensorBoard x-axis; incremented once per round.
  global_step = tf.Variable(1, name='global_step', dtype=tf.int64)
  file_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.root_output_dir))
  file_writer.set_as_default()
  write_print(file_handle, '=======configurations========')
  write_print(file_handle, configs)
  write_print(file_handle, '=======configurations========')
  # Prepare the federated training dataset.
  write_print(file_handle, 'Loading Dataset!')
  emnist_train, _ = tff.simulation.datasets.emnist.load_data(
      only_digits=FLAGS.only_digits)

  # Prepare the centralized test set.
  write_print(file_handle, 'Loading Test Set!')
  test_image, test_label = load_test_data()

  # Load the malicious dataset plus the attack-target examples used later
  # to measure the attack's success during evaluation.
  write_print(file_handle, 'Loading malicious dataset!')
  dataset_malicious, target_x, target_y = load_malicious_dataset(FLAGS.task_num)

  # Derive the model input spec from one preprocessed client dataset.
  example_dataset = preprocess(
      emnist_train.create_tf_dataset_for_client(emnist_train.client_ids[0]))
  input_spec = example_dataset.element_spec

  def model_fn():
    # Fresh Keras model wrapped for TFF; called once per participant.
    keras_model = create_keras_model()
    return tff.learning.from_keras_model(
        keras_model,
        input_spec=input_spec,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

  # Server optimizer: SGD with optional (Nesterov) momentum.
  nesterov = True if FLAGS.server_momentum != 0 else False

  def server_optimizer_fn():
    return tf.keras.optimizers.SGD(
        learning_rate=FLAGS.server_learning_rate,
        momentum=FLAGS.server_momentum,
        nesterov=nesterov)

  # Build the attacked iterative process: norm-bounded boosting attack on
  # the client side, Gaussian-query DP averaging on the server side.
  write_print(file_handle, 'Building Iterative Process!')
  client_update_function = attacked_fedavg.ClientProjectBoost(
      boost_factor=float(FLAGS.num_clients_per_round),
      norm_bound=FLAGS.norm_bound,
      round_num=FLAGS.client_round_num)
  query = tensorflow_privacy.GaussianAverageQuery(FLAGS.l2_norm_clip,
                                                  FLAGS.mul_factor,
                                                  FLAGS.num_clients_per_round)
  dp_aggregate_fn = tff.utils.build_dp_aggregate_process(
      tff.learning.framework.weights_type_from_model(model_fn), query)
  iterative_process = attacked_fedavg.build_federated_averaging_process_attacked(
      model_fn=model_fn,
      aggregation_process=dp_aggregate_fn,
      client_update_tf=client_update_function,
      server_optimizer_fn=server_optimizer_fn)
  state = iterative_process.initialize()

  # Training loop.
  for cur_round in range(FLAGS.num_rounds):
    # The attacker participates once every `attack_freq` rounds, in the
    # middle of each period.
    if cur_round % FLAGS.attack_freq == FLAGS.attack_freq // 2:
      with_attack = 1
      write_print(file_handle, 'Attacker appears!')
    else:
      with_attack = 0

    # Sample clients and make the federated dataset for this round.
    federated_train_data, federated_malicious_data, client_type_list = \
        sample_clients_with_malicious(
            emnist_train, client_ids=emnist_train.client_ids,
            dataset_malicious=dataset_malicious,
            num_clients=FLAGS.num_clients_per_round, with_attack=with_attack)

    # One round of attacked federated averaging.
    write_print(file_handle, 'Round starts!')
    state, train_metrics = iterative_process.next(state, federated_train_data,
                                                  federated_malicious_data,
                                                  client_type_list)

    write_print(
        file_handle,
        'Training round {:2d}, train_metrics={}'.format(cur_round,
                                                        train_metrics))

    log_tfboard('train_acc', train_metrics['sparse_categorical_accuracy'],
                global_step)
    log_tfboard('train_loss', train_metrics['loss'], global_step)

    # Periodically evaluate the current model on the test data and on the
    # attack-target examples.
    if cur_round % FLAGS.evaluate_per_rounds == 0:
      test_metrics, test_metrics_target = evaluate(state, test_image,
                                                   test_label, target_x,
                                                   target_y)
      write_print(
          file_handle,
          'Evaluation round {:2d}, <sparse_categorical_accuracy={},loss={}>'
          .format(cur_round, test_metrics[1], test_metrics[0]))
      write_print(
          file_handle,
          'Evaluation round {:2d}, <sparse_categorical_accuracy={},loss={}>'
          .format(cur_round, test_metrics_target[1], test_metrics_target[0]))
      log_tfboard('test_acc', test_metrics[1], global_step)
      log_tfboard('test_loss', test_metrics[0], global_step)
      log_tfboard('test_acc_target', test_metrics_target[1], global_step)
      log_tfboard('test_loss_target', test_metrics_target[0], global_step)

    global_step.assign_add(1)