Example #1
    def test_aggregated_updates_as_input_fails(self):
        """Expect per-example gradients as input to this transform."""
        dp_agg = privacy.differentially_private_aggregate(
            l2_norm_clip=0.1,
            noise_multiplier=1.1,
            seed=2021)
        state = dp_agg.init(self.params)
        mean_grads = jax.tree_map(lambda g: g.mean(0), self.per_eg_grads)
        with self.assertRaises(ValueError):
            dp_agg.update(mean_grads, state, self.params)
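A minimal sketch of how the expected per-example gradients can be produced with `jax.vmap`; the loss function, parameters, and batch below are hypothetical stand-ins, not part of the test above.

import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Hypothetical per-example loss: a linear model with squared error.
    return (jnp.dot(x, params) - y) ** 2

params = jnp.zeros(3)
xs, ys = jnp.ones((8, 3)), jnp.ones(8)

# Map over the batch axis of (x, y) but not over params, yielding one
# gradient per example, stacked along axis 0.
per_eg_grads = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))(params, xs, ys)
assert per_eg_grads.shape == (8, 3)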
Example #2
    def test_no_privacy(self):
        """l2_norm_clip=MAX_FLOAT32 and noise_multiplier=0 should recover SGD."""
        dp_agg = privacy.differentially_private_aggregate(
            l2_norm_clip=jnp.finfo(jnp.float32).max,
            noise_multiplier=0.,
            seed=0)
        state = dp_agg.init(self.params)
        update_fn = self.variant(dp_agg.update)
        mean_grads = jax.tree_map(lambda g: g.mean(0), self.per_eg_grads)

        for _ in range(3):
            updates, state = update_fn(self.per_eg_grads, state)
            chex.assert_tree_all_close(updates, mean_grads)
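As a worked statement of the no-privacy limit this test relies on: the aggregate follows the Gaussian mechanism of Abadi et al. (2016, referenced in the `dpsgd` docstring below); a sketch of the formula, under that assumption:

\[
\tilde{g} = \frac{1}{B}\left(\sum_{i=1}^{B} \frac{g_i}{\max\left(1,\ \lVert g_i \rVert_2 / C\right)} + \mathcal{N}\left(0,\ (C\sigma)^2 I\right)\right)
\]

where \(C\) is `l2_norm_clip`, \(\sigma\) is `noise_multiplier`, and \(B\) is the batch size. With \(C\) set to MAX_FLOAT32 every divisor is 1, and with \(\sigma = 0\) the noise term vanishes, so \(\tilde{g}\) reduces to the plain mean of the per-example gradients, i.e. ordinary SGD.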
Example #3
    def test_noise_multiplier(self, l2_norm_clip, noise_multiplier):
        """Standard dev. of noise should be l2_norm_clip * noise_multiplier."""
        dp_agg = privacy.differentially_private_aggregate(
            l2_norm_clip=l2_norm_clip,
            noise_multiplier=noise_multiplier,
            seed=1337)
        state = dp_agg.init(None)
        update_fn = self.variant(dp_agg.update)
        expected_std = l2_norm_clip * noise_multiplier

        grads = [jnp.ones((1, 100, 100))]  # batch size 1
        for _ in range(3):
            updates, state = update_fn(grads, state)
            chex.assert_tree_all_close(expected_std,
                                       jnp.std(updates[0]),
                                       atol=0.1 * expected_std)
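Note that in the mechanism sketched after Example #2, the noise is added to the sum of clipped gradients before dividing by the batch size, so the noise on the mean has standard deviation \(C\sigma / B\); the test pins the batch size to 1 so this is exactly `l2_norm_clip * noise_multiplier`. A self-contained sketch of the sample-std check, with illustrative values:

import jax
import jax.numpy as jnp

l2_norm_clip, noise_multiplier = 3.0, 0.8
expected_std = l2_norm_clip * noise_multiplier

# The sample std of 10,000 Gaussian draws lands within a few percent of the
# true std, which is what the assertion above tolerates (atol of 10%).
key = jax.random.PRNGKey(1337)
noise = expected_std * jax.random.normal(key, (100, 100))
print(jnp.std(noise), expected_std)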
Example #4
File: alias.py Project: wanglouis49/optax
def dpsgd(
    learning_rate: ScalarOrSchedule,
    l2_norm_clip: float,
    noise_multiplier: float,
    seed: int,
    momentum: Optional[float] = None,
    nesterov: bool = False
) -> base.GradientTransformation:
  """The DPSGD optimiser.

  Differential privacy is a standard for privacy guarantees of algorithms
  learning from aggregate databases including potentially sensitive information.
  DPSGD offers protection against a strong adversary with full knowledge of the
  training mechanism and access to the model’s parameters.

  WARNING: This `GradientTransformation` expects input updates to have a batch
  dimension on the 0th axis. That is, this function expects per-example
  gradients as input (which are easy to obtain in JAX using `jax.vmap`).

  References:
    Abadi et al, 2016: https://arxiv.org/abs/1607.00133

  Args:
    learning_rate: this is a fixed global scaling factor.
    l2_norm_clip: maximum L2 norm of the per-example gradients.
    noise_multiplier: ratio of standard deviation to the clipping norm.
    seed: initial seed used for the `jax.random.PRNGKey`.
    momentum: (default `None`), the `decay` rate used by the momentum term.
      When set to `None`, momentum is not used at all.
    nesterov: (default `False`), whether Nesterov momentum is used.

  Returns:
    A `GradientTransformation`.
  """
  return combine.chain(
      privacy.differentially_private_aggregate(
          l2_norm_clip=l2_norm_clip,
          noise_multiplier=noise_multiplier,
          seed=seed),
      (transform.trace(decay=momentum, nesterov=nesterov)
       if momentum is not None else base.identity()),
      _scale_by_learning_rate(learning_rate)
  )
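A hypothetical end-to-end sketch of driving this optimiser; the linear model, loss, and data are illustrative stand-ins, not part of optax. It demonstrates the per-example-gradient contract from the WARNING above, assuming `dpsgd` is exposed at the package level as in this project.

import jax
import jax.numpy as jnp
import optax

def loss_fn(params, x, y):
    return (jnp.dot(x, params) - y) ** 2

params = jnp.zeros(3)
xs, ys = jnp.ones((8, 3)), jnp.ones(8)

opt = optax.dpsgd(learning_rate=0.1, l2_norm_clip=1.0,
                  noise_multiplier=1.1, seed=0, momentum=0.9)
state = opt.init(params)

# Per-example gradients via vmap (batch axis on 0), then the DP-SGD update.
per_eg_grads = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))(params, xs, ys)
updates, state = opt.update(per_eg_grads, state, params)
params = optax.apply_updates(params, updates)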
Example #5
    def test_clipping_norm(self, l2_norm_clip):
        dp_agg = privacy.differentially_private_aggregate(
            l2_norm_clip=l2_norm_clip, noise_multiplier=0., seed=42)
        state = dp_agg.init(self.params)
        update_fn = self.variant(dp_agg.update)

        # Shape of the three arrays below is (self.batch_size, )
        norms = [
            jnp.linalg.norm(g.reshape(self.batch_size, -1), axis=1)
            for g in jax.tree_leaves(self.per_eg_grads)
        ]
        global_norms = jnp.linalg.norm(jnp.stack(norms), axis=0)
        divisors = jnp.maximum(global_norms / l2_norm_clip, 1.)
        # Since the values of all the parameters are the same within each example,
        # we can easily compute what the values should be:
        expected_val = jnp.mean(jnp.arange(self.batch_size) / divisors)
        expected_tree = jax.tree_map(
            lambda p: jnp.broadcast_to(expected_val, p.shape), self.params)

        for _ in range(3):
            updates, state = update_fn(self.per_eg_grads, state, self.params)
            chex.assert_tree_all_close(updates, expected_tree)
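The expected value above encodes the standard per-example clipping rule, \(g_i \mapsto g_i / \max(1,\ \lVert g_i \rVert_2 / C)\). A minimal single-array restatement, with an illustrative clip value and gradients shaped like the test's (example \(i\) holds the constant value \(i\)):

import jax.numpy as jnp

def clip_per_example(g, l2_norm_clip):
    # Scale each example's gradient so its L2 norm is at most l2_norm_clip.
    norms = jnp.linalg.norm(g.reshape(g.shape[0], -1), axis=1)
    divisors = jnp.maximum(norms / l2_norm_clip, 1.0)
    return g / divisors[:, None]

g = jnp.arange(4.0)[:, None] * jnp.ones((4, 3))  # example i has value i everywhere
print(clip_per_example(g, l2_norm_clip=1.0).mean(0))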