Ejemplo n.º 1
0
    def test_self_contained_example_noclip(self, clip_norm=-1):
        """Checks that a second client update yields a lower loss than the first.

        A single model is reused across two rounds. Each round builds a fresh
        SGD optimizer (learning rate 0.1), wraps the current model weights and
        `clip_norm` in a broadcast message, and runs
        `dp_fedavg.client_update` on the dataset returned by `client_data()`.

        Args:
            clip_norm: Value forwarded as `dp_clip_norm` in the broadcast
                message. NOTE(review): going by the test name, the default of
                -1 presumably disables clipping -- confirm against
                `dp_fedavg.client_update`.
        """
        client_data = _create_client_data()
        model = MnistModel()

        round_losses = []
        for _ in range(2):
            # A new optimizer per round, so no optimizer object is reused.
            sgd = tf.keras.optimizers.SGD(learning_rate=0.1)
            _initialize_optimizer_vars(model, sgd)
            broadcast = dp_fedavg.BroadcastMessage(
                model_weights=model.weights, dp_clip_norm=clip_norm)
            outputs = dp_fedavg.client_update(
                model, client_data(), broadcast, sgd)
            round_losses.append(outputs.model_output.numpy())

        # `outputs` still refers to the result of the final round here.
        first_loss, second_loss = round_losses
        self.assertAllEqual(int(outputs.client_weight.numpy()), 2)
        self.assertLess(second_loss, first_loss)
Ejemplo n.º 2
0
  def test_self_contained_example(self, clip_norm, simulation_flag):
    """Checks clipped client updates stay within `clip_norm` and reduce loss.

    A single model is reused across two rounds. Each round builds a fresh SGD
    optimizer (learning rate 0.1), wraps the current model weights and
    `clip_norm` in a broadcast message, runs `dp_fedavg.client_update`, and
    asserts that the global norm of the returned weights delta does not
    exceed `clip_norm`.

    Args:
      clip_norm: Value forwarded as `dp_clip_norm` in the broadcast message
        and used as the upper bound for the weights-delta global norm.
      simulation_flag: Passed positionally as the fifth argument of
        `dp_fedavg.client_update`. NOTE(review): its meaning is defined by
        that function and is not visible here -- confirm before changing.
    """
    client_data = _create_client_data()
    model = MnistModel()

    round_losses = []
    for _ in range(2):
      # A new optimizer per round, so no optimizer object is reused.
      sgd = tf.keras.optimizers.SGD(learning_rate=0.1)
      _initialize_optimizer_vars(model, sgd)
      broadcast = dp_fedavg.BroadcastMessage(
          model_weights=model.weights, dp_clip_norm=clip_norm)
      outputs = dp_fedavg.client_update(
          model, client_data(), broadcast, sgd, simulation_flag)
      round_losses.append(outputs.model_output.numpy())

      # The bound is checked on every round, not only the last one.
      flat_delta = tf.nest.flatten(outputs.weights_delta)
      delta_norm = tf.linalg.global_norm(flat_delta)
      self.assertLessEqual(delta_norm, clip_norm)

    # `outputs` still refers to the result of the final round here.
    first_loss, second_loss = round_losses
    self.assertAllEqual(int(outputs.client_weight.numpy()), 2)
    self.assertLess(second_loss, first_loss)