def _check_one_layer_with_criterion(self, layer, criterion, *args, **kwargs):
    self.validator.validate(layer)
    for name, param in layer.named_parameters():
        if ("weight" in name) or ("bias" in name):
            nn.init.uniform_(param, -1.0, 1.0)

    # run without DP
    self._run_once(layer, criterion, *args)
    vanilla_run_grads = [
        (name, p.grad.detach())
        for (name, p) in layer.named_parameters()
        if p.requires_grad
    ]

    # run with DP
    clipper = PerSampleGradientClipper(
        layer,
        ConstantFlatClipper(1e9),
        batch_first=kwargs.get("batch_first", True),
        loss_reduction=criterion.reduction,
    )
    self._run_once(layer, criterion, *args)

    for param_name, param in layer.named_parameters():
        if param.requires_grad:
            self.assertTrue(
                hasattr(param, "grad_sample"),
                f"Per-sample gradients haven't been computed for {param_name}",
            )

    clipper.clip_and_accumulate()
    clipper.pre_step()

    private_run_grads = [
        (name, p.grad.detach())
        for (name, p) in layer.named_parameters()
        if p.requires_grad
    ]

    # compare
    for (vanilla_name, vanilla_grad), (private_name, private_grad) in zip(
        vanilla_run_grads, private_run_grads
    ):
        assert vanilla_name == private_name
        self.assertTrue(
            torch.allclose(vanilla_grad, private_grad, atol=10e-5, rtol=10e-3),
            f"Gradient mismatch. Parameter: {layer}.{vanilla_name}, loss: {criterion.reduction}",
        )

    clipper.close()
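# A minimal usage sketch of the helper above. The layer, loss, input shape, and
# the method name test_linear_example are illustrative assumptions, not part of
# the original snippet; the helper relies on the surrounding test case providing
# self.validator and self._run_once(layer, criterion, *args).
def test_linear_example(self):
    layer = nn.Linear(8, 4)
    x = torch.randn(24, 8)  # 24 samples, 8 features, batch dimension first

    # a huge clipping threshold (1e9 in the helper) means clipping is a no-op,
    # so the aggregated DP gradients should match the vanilla gradients
    self._check_one_layer_with_criterion(
        layer, nn.L1Loss(reduction="mean"), x, batch_first=True
    )
    self._check_one_layer_with_criterion(
        layer, nn.L1Loss(reduction="sum"), x, batch_first=True
    )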
def _check_one_layer_with_criterion(self, layer, criterion, data, batch_first=True):
    clipper = PerSampleGradientClipper(
        layer,
        ConstantFlatClipper(1e9),
        batch_first=batch_first,
        loss_reduction=criterion.reduction,
    )
    self._run_once(layer, criterion, data)

    computed_sample_grads = {}
    for (param_name, param) in layer.named_parameters():
        computed_sample_grads[param_name] = param.grad_sample.detach()

    clipper.clip_and_accumulate()
    clipper.pre_step()
    clipper.close()

    batch_dim = 0 if batch_first else 1
    data = data.transpose(0, batch_dim)
    for i, sample in enumerate(data):
        # simulate batch_size = 1
        sample_data = sample.unsqueeze(batch_dim)
        self._run_once(layer, criterion, sample_data)

        for (param_name, param) in layer.named_parameters():
            # grad we just computed with batch_size = 1
            vanilla_per_sample_grad = param.grad
            # i-th line in grad_sample computed before
            computed_per_sample_grad = computed_sample_grads[param_name][i]

            self.assertTrue(
                torch.allclose(
                    vanilla_per_sample_grad,
                    computed_per_sample_grad,
                    atol=10e-5,
                    rtol=10e-3,
                ),
                f"Gradient mismatch. Parameter: {layer}.{param_name}, loss: {criterion.reduction}",
            )
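# A minimal usage sketch of the per-sample comparison above: each row of
# param.grad_sample is checked against a separate batch_size = 1 run. The layer,
# loss, input shapes, and the name test_per_sample_example are illustrative
# assumptions; _run_once is again expected to come from the surrounding test case.
def test_per_sample_example(self):
    layer = nn.Linear(8, 4)

    # batch dimension first: (batch=16, features=8)
    x = torch.randn(16, 8)
    self._check_one_layer_with_criterion(
        layer, nn.L1Loss(reduction="sum"), x, batch_first=True
    )

    # batch dimension second: (time=5, batch=16, features=8), exercising the
    # transpose/unsqueeze path that re-inserts the batch dim at position 1
    x_seq = torch.randn(5, 16, 8)
    self._check_one_layer_with_criterion(
        layer, nn.L1Loss(reduction="mean"), x_seq, batch_first=False
    )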