def _check_one_layer(self, layer, input):
    """Check that per-sample clipping with a huge bound leaves ``layer``'s grads unchanged.

    The layer's ``weight``/``bias`` (when present and not ``None``) are
    re-initialised uniformly. Then two forward/backward passes are run on
    ``input``, each backpropagating ``output.norm()``:

    1. a plain run;
    2. a run with a ``PerSampleGradientClipper`` (clip bound 999) attached,
       followed by ``clipper.step()``.

    The gradients of every trainable parameter from both runs are asserted
    to be ``allclose``.

    Args:
        layer: a freshly constructed module with no grad-sample hooks attached.
        input: the tensor passed to ``layer``.

    Raises:
        ValueError: if ``layer`` already carries ``autograd_grad_sample_hooks``.
    """
    if hasattr(layer, "autograd_grad_sample_hooks"):
        raise ValueError(
            "Input layer already has hooks attached. "
            "Please provide freshly constructed layer"
        )

    # Modules built without a bias (or weight) still expose the attribute,
    # set to None, and nn.init.uniform_(None) would fail.
    if getattr(layer, "weight", None) is not None:
        nn.init.uniform_(layer.weight)
    if getattr(layer, "bias", None) is not None:
        nn.init.uniform_(layer.bias)

    # Run without DP. Clear any stale grads so only this pass is measured.
    layer.zero_grad()
    output = layer(input)
    output.norm().backward()
    vanilla_run_grads = [
        p.grad.detach().clone() for p in layer.parameters() if p.requires_grad
    ]

    # Run with DP. Zero the grads again so the vanilla run's gradients do not
    # accumulate into this run's .grad.
    clipper = PerSampleGradientClipper(layer, 999)
    layer.zero_grad()
    output = layer(input)
    output.norm().backward()
    clipper.step()
    private_run_grads = [
        p.grad.detach().clone() for p in layer.parameters() if p.requires_grad
    ]

    for vanilla_grad, private_grad in zip(vanilla_run_grads, private_run_grads):
        self.assertTrue(
            torch.allclose(vanilla_grad, private_grad, atol=10e-5, rtol=10e-3)
        )
def _check_one_layer(self, layer, *args, **kwargs):
    """Check that per-sample clipping with a huge bound reproduces the plain gradients.

    The steps are:

    - Validate ``layer`` with ``self.validator``.
    - Re-initialise its ``weight``/``bias`` uniformly (when present and not ``None``).
    - Run it twice on ``*args``: once plainly, and once with a
      ``PerSampleGradientClipper`` (clip bound 999) attached.

    Each run resets seeds and zeroes grads before the forward pass. If the
    layer returns a tuple, the first element is the one differentiated. The
    DP run must set ``grad_sample`` on every trainable parameter. After
    ``clipper.step()``, every trainable parameter's gradient must be
    ``allclose`` to the plain run's.

    Args:
        layer: a freshly constructed module with no grad-sample hooks.
        *args: positional inputs forwarded to ``layer``.
        **kwargs: ``batch_dim`` (default 0) is forwarded to the clipper;
            other keys are ignored.

    Raises:
        ValueError: if ``layer`` already carries ``autograd_grad_sample_hooks``.
    """
    if hasattr(layer, "autograd_grad_sample_hooks"):
        raise ValueError(
            "Input layer already has hooks attached. "
            "Please provide freshly constructed layer"
        )
    self.validator.validate(layer)

    # hasattr() alone is not enough: modules built with bias=False (or without
    # an affine weight) still expose the attribute, set to None, and
    # nn.init.uniform_(None) fails.
    if getattr(layer, "weight", None) is not None:
        nn.init.uniform_(layer.weight)
    if getattr(layer, "bias", None) is not None:
        nn.init.uniform_(layer.bias)

    def _forward_backward():
        # Identical seeds and cleared grads make the two runs directly comparable.
        self._reset_seeds()
        layer.zero_grad()
        output = layer(*args)
        if isinstance(output, tuple):
            # Only the first element of a tuple output is differentiated.
            output = output[0]
        output.norm().backward()

    def _collect_grads():
        return [
            p.grad.detach().clone() for p in layer.parameters() if p.requires_grad
        ]

    # run without DP
    _forward_backward()
    vanilla_run_grads = _collect_grads()

    # run with DP
    clipper = PerSampleGradientClipper(
        layer, 999, batch_dim=kwargs.get("batch_dim", 0)
    )
    _forward_backward()
    for param_name, param in layer.named_parameters():
        if param.requires_grad:
            self.assertTrue(
                hasattr(param, "grad_sample"),
                f"Per-sample gradients hasn't been computed for {param_name}",
            )
    clipper.step()
    private_run_grads = _collect_grads()

    # compare
    for vanilla_grad, private_grad in zip(vanilla_run_grads, private_run_grads):
        self.assertTrue(
            torch.allclose(vanilla_grad, private_grad, atol=10e-5, rtol=10e-3)
        )
class PerSampleGradientClipper_test(unittest.TestCase):
    """Compare a SampleConvNet's gradients with and without per-sample clipping.

    ``setUp`` builds a fake dataset of ``DATA_SIZE`` 1x35x35 images and
    backpropagates a cross-entropy loss over every batch twice:

    - on an unclipped ``SampleConvNet``;
    - on a copy with the same weights that has a ``PerSampleGradientClipper``
      attached.

    The per-parameter gradient norms of both models are stored for the tests
    to compare.
    """

    def setUp(self):
        self.DATA_SIZE = 64
        self.criterion = nn.CrossEntropyLoss()
        self.setUp_data()
        self.setUp_original_model()
        self.setUp_clipped_model(clip_value=0.003, run_clipper_step=True)

    @staticmethod
    def _grad_norms(model):
        """Return the grad norms of ``model``'s trainable parameters, stacked into one tensor."""
        norms = []
        for param in model.parameters():
            if param.requires_grad:
                norms.append(param.grad.norm())
        return torch.stack(norms, dim=-1)

    def setUp_data(self):
        pipeline = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )
        self.ds = FakeData(
            size=self.DATA_SIZE,
            image_size=(1, 35, 35),
            num_classes=10,
            transform=pipeline,
        )
        # A single batch holding the whole dataset.
        self.dl = DataLoader(self.ds, batch_size=self.DATA_SIZE)

    def setUp_original_model(self):
        self.original_model = SampleConvNet()
        net = self.original_model
        for images, labels in self.dl:
            # backward() populates .grad on net's parameters
            self.criterion(net(images), labels).backward()
        self.original_grads_norms = self._grad_norms(net)

    def setUp_clipped_model(self, clip_value=0.003, run_clipper_step=True):
        # Build a fresh network and load the original's weights into it
        # (state_dict carries weights, not gradients).
        self.clipped_model = SampleConvNet()
        net = self.clipped_model
        net.load_state_dict(self.original_model.state_dict())

        # The clipper is attached before the backward pass; clip_value is
        # deliberately tiny in the default setUp.
        self.clipper = PerSampleGradientClipper(net, clip_value)
        for images, labels in self.dl:
            self.criterion(net(images), labels).backward()

        if run_clipper_step:
            self.clipper.step()
        self.clipped_grads_norms = self._grad_norms(net)

    def test_clipped_grad_norm_is_smaller(self):
        """Every parameter's grad norm must shrink once clipping is applied."""
        for before, after in zip(self.original_grads_norms, self.clipped_grads_norms):
            self.assertLess(float(after), float(before))

    def test_clipped_grad_norms_not_zero(self):
        """Clipping must not wipe the gradients out entirely."""
        zeros = torch.zeros_like(self.clipped_grads_norms)
        self.assertFalse(torch.allclose(self.clipped_grads_norms, zeros))

    def test_clipping_to_high_value_does_nothing(self):
        """A clip bound far above any grad norm must leave the norms unchanged."""
        self.setUp_clipped_model(clip_value=9999, run_clipper_step=True)
        self.assertTrue(
            torch.allclose(self.original_grads_norms, self.clipped_grads_norms)
        )

    def test_grad_norms_untouched_without_clip_step(self):
        """Grads must stay unclipped until clipper.step() is called."""
        self.setUp_clipped_model(clip_value=0.003, run_clipper_step=False)
        self.assertTrue(
            torch.allclose(self.original_grads_norms, self.clipped_grads_norms)
        )