def test_aggregated_updates_as_input_fails(self):
  """Expect per-example gradients as input to this transform."""
  dp_agg = privacy.differentially_private_aggregate(
      l2_norm_clip=0.1, noise_multiplier=1.1, seed=2021)
  state = dp_agg.init(self.params)
  mean_grads = jax.tree_map(lambda g: g.mean(0), self.per_eg_grads)
  with self.assertRaises(ValueError):
    dp_agg.update(mean_grads, state, self.params)
def test_no_privacy(self):
  """l2_norm_clip=MAX_FLOAT32 and noise_multiplier=0 should recover SGD."""
  dp_agg = privacy.differentially_private_aggregate(
      l2_norm_clip=jnp.finfo(jnp.float32).max,
      noise_multiplier=0.,
      seed=0)
  state = dp_agg.init(self.params)
  update_fn = self.variant(dp_agg.update)
  mean_grads = jax.tree_map(lambda g: g.mean(0), self.per_eg_grads)
  for _ in range(3):
    updates, state = update_fn(self.per_eg_grads, state)
    chex.assert_tree_all_close(updates, mean_grads)
def test_noise_multiplier(self, l2_norm_clip, noise_multiplier):
  """Standard dev. of noise should be l2_norm_clip * noise_multiplier."""
  dp_agg = privacy.differentially_private_aggregate(
      l2_norm_clip=l2_norm_clip,
      noise_multiplier=noise_multiplier,
      seed=1337)
  state = dp_agg.init(None)
  update_fn = self.variant(dp_agg.update)
  expected_std = l2_norm_clip * noise_multiplier

  grads = [jnp.ones((1, 100, 100))]  # batch size 1
  for _ in range(3):
    updates, state = update_fn(grads, state)
    chex.assert_tree_all_close(expected_std, jnp.std(updates[0]),
                               atol=0.1 * expected_std)
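# A standalone sketch of the relationship checked above (an assumption
# following Abadi et al., 2016, not code copied from the library; the helper
# name is hypothetical): the aggregate is the sum of clipped per-example
# gradients plus Gaussian noise with standard deviation
# l2_norm_clip * noise_multiplier, divided by the batch size. With a batch
# size of 1, the noise std of the aggregate is l2_norm_clip * noise_multiplier.
def _noisy_aggregate_sketch(clipped_grad_sum, l2_norm_clip, noise_multiplier,
                            batch_size, key):
  noise = l2_norm_clip * noise_multiplier * jax.random.normal(
      key, clipped_grad_sum.shape)
  return (clipped_grad_sum + noise) / batch_size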
def dpsgd(
    learning_rate: ScalarOrSchedule,
    l2_norm_clip: float,
    noise_multiplier: float,
    seed: int,
    momentum: Optional[float] = None,
    nesterov: bool = False
) -> base.GradientTransformation:
  """The DPSGD optimiser.

  Differential privacy is a standard for privacy guarantees of algorithms
  learning from aggregate databases including potentially sensitive
  information. DPSGD offers protection against a strong adversary with full
  knowledge of the training mechanism and access to the model's parameters.

  WARNING: This `GradientTransformation` expects input updates to have a batch
  dimension on the 0th axis. That is, this function expects per-example
  gradients as input (which are easy to obtain in JAX using `jax.vmap`).

  References:
    Abadi et al, 2016: https://arxiv.org/abs/1607.00133

  Args:
    learning_rate: this is a fixed global scaling factor.
    l2_norm_clip: maximum L2 norm of the per-example gradients.
    noise_multiplier: ratio of standard deviation to the clipping norm.
    seed: initial seed used for the jax.random.PRNGKey.
    momentum: (default `None`), the `decay` rate used by the momentum term;
      when it is set to `None`, momentum is not used at all.
    nesterov: (default `False`), whether Nesterov momentum is used.

  Returns:
    A `GradientTransformation`.
  """
  return combine.chain(
      privacy.differentially_private_aggregate(
          l2_norm_clip=l2_norm_clip,
          noise_multiplier=noise_multiplier,
          seed=seed),
      (transform.trace(decay=momentum, nesterov=nesterov)
       if momentum is not None else base.identity()),
      _scale_by_learning_rate(learning_rate)
  )
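# A minimal usage sketch for `dpsgd` (a hypothetical example, not part of the
# library): the toy linear model and data below are made up. Per-example
# gradients are obtained with `jax.vmap` over a single-example loss, as the
# WARNING in the docstring above requires; the returned updates already carry
# the -learning_rate scaling, so they are simply added to the parameters.
def _dpsgd_usage_sketch():
  import jax
  import jax.numpy as jnp

  params = {'w': jnp.zeros((3,)), 'b': jnp.zeros(())}
  xs, ys = jnp.ones((8, 3)), jnp.ones((8,))  # toy batch of 8 examples

  def loss(p, x, y):  # loss on a single example
    return (jnp.dot(x, p['w']) + p['b'] - y) ** 2

  # Per-example gradients: batch dimension on the 0th axis of every leaf.
  per_eg_grads = jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))(params, xs, ys)

  opt = dpsgd(learning_rate=0.1, l2_norm_clip=1.0,
              noise_multiplier=1.1, seed=0, momentum=0.9)
  opt_state = opt.init(params)
  updates, opt_state = opt.update(per_eg_grads, opt_state, params)
  # Apply the updates (already scaled by -learning_rate) to the parameters.
  return jax.tree_map(lambda p, u: p + u, params, updates)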
def test_clipping_norm(self, l2_norm_clip):
  dp_agg = privacy.differentially_private_aggregate(
      l2_norm_clip=l2_norm_clip,
      noise_multiplier=0.,
      seed=42)
  state = dp_agg.init(self.params)
  update_fn = self.variant(dp_agg.update)

  # Shape of the three arrays below is (self.batch_size, )
  norms = [jnp.linalg.norm(g.reshape(self.batch_size, -1), axis=1)
           for g in jax.tree_leaves(self.per_eg_grads)]
  global_norms = jnp.linalg.norm(jnp.stack(norms), axis=0)
  divisors = jnp.maximum(global_norms / l2_norm_clip, 1.)
  # Since the values of all the parameters are the same within each example,
  # we can easily compute what the values should be:
  expected_val = jnp.mean(jnp.arange(self.batch_size) / divisors)
  expected_tree = jax.tree_map(
      lambda p: jnp.broadcast_to(expected_val, p.shape), self.params)

  for _ in range(3):
    updates, state = update_fn(self.per_eg_grads, state, self.params)
    chex.assert_tree_all_close(updates, expected_tree)
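# For reference, a standalone sketch of the clipping rule the expected values
# above encode (an assumption following Abadi et al., 2016, not code copied
# from the library; the helper name is hypothetical): each per-example
# gradient is rescaled so that its global L2 norm across all leaves is at most
# l2_norm_clip, and the clipped gradients are then averaged over the batch.
def _clip_and_average_sketch(per_eg_grads, l2_norm_clip, batch_size):
  leaves = jax.tree_leaves(per_eg_grads)
  norms = [jnp.linalg.norm(g.reshape(batch_size, -1), axis=1) for g in leaves]
  global_norms = jnp.linalg.norm(jnp.stack(norms), axis=0)
  divisors = jnp.maximum(global_norms / l2_norm_clip, 1.)
  clip = lambda g: g / divisors.reshape((batch_size,) + (1,) * (g.ndim - 1))
  return jax.tree_map(lambda g: clip(g).mean(0), per_eg_grads)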