Example #1
    def _check_one_layer(self, layer, input):
        if hasattr(layer, "autograd_grad_sample_hooks"):
            raise ValueError(f"Input layer already has hooks attached."
                             f"Please provide freshly constructed layer")

        nn.init.uniform_(layer.weight)
        nn.init.uniform_(layer.bias)

        output = layer(input)
        output.norm().backward()
        vanilla_run_grads = [
            p.grad.detach().clone() for p in layer.parameters()
            if p.requires_grad
        ]

        clipper = PerSampleGradientClipper(layer, 999)
        output = layer(input)
        output.norm().backward()
        clipper.step()
        private_run_grads = [
            p.grad.detach().clone() for p in layer.parameters()
            if p.requires_grad
        ]

        for vanilla_grad, private_grad in zip(vanilla_run_grads,
                                              private_run_grads):
            self.assertTrue(
                torch.allclose(vanilla_grad,
                               private_grad,
                               atol=10e-5,
                               rtol=10e-3))
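
The helper relies on a few names that the listing does not show. Below is a minimal sketch of the assumed setup and a typical call; the layer type, input shape, and the import path of PerSampleGradientClipper are assumptions. Note that the clipping threshold of 999 in the helper is deliberately large, so clipping should leave the aggregated gradients effectively unchanged.

import torch
import torch.nn as nn
from torchdp import PerSampleGradientClipper  # assumed import path

# The helper expects a freshly constructed layer (no grad-sample hooks
# attached yet) and an input whose first dimension is the batch.
# Inside the test case one would then call:
layer = nn.Linear(8, 4)
batch = torch.randn(16, 8)
# self._check_one_layer(layer, batch)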
Example #2
    def _check_one_layer(self, layer, *args, **kwargs):
        if hasattr(layer, "autograd_grad_sample_hooks"):
            raise ValueError(f"Input layer already has hooks attached."
                             f"Please provide freshly constructed layer")

        self.validator.validate(layer)
        if hasattr(layer, "weight"):
            nn.init.uniform_(layer.weight)
        if hasattr(layer, "bias"):
            nn.init.uniform_(layer.bias)

        # run without DP
        self._reset_seeds()
        layer.zero_grad()
        output = layer(*args)
        if isinstance(output, tuple):
            output = output[0]
        output.norm().backward()
        vanilla_run_grads = [
            p.grad.detach().clone() for p in layer.parameters()
            if p.requires_grad
        ]

        # run with DP
        clipper = PerSampleGradientClipper(layer,
                                           999,
                                           batch_dim=kwargs.get(
                                               "batch_dim", 0))
        self._reset_seeds()
        layer.zero_grad()
        output = layer(*args)
        if isinstance(output, tuple):
            output = output[0]
        output.norm().backward()

        for param_name, param in layer.named_parameters():
            if param.requires_grad:
                self.assertTrue(
                    hasattr(param, "grad_sample"),
                    f"Per-sample gradients hasn't been computed for {param_name}",
                )

        clipper.step()

        private_run_grads = [
            p.grad.detach().clone() for p in layer.parameters()
            if p.requires_grad
        ]

        # compare
        for vanilla_grad, private_grad in zip(vanilla_run_grads,
                                              private_run_grads):
            self.assertTrue(
                torch.allclose(vanilla_grad,
                               private_grad,
                               atol=10e-5,
                               rtol=10e-3))
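
This variant also calls self.validator.validate(layer) and self._reset_seeds(), neither of which appears in the listing. A plausible sketch of those helpers, assuming the validator is a module-compatibility checker along the lines of torchdp's DPModelInspector (class name and import path below are assumptions):

import unittest

import torch
from torchdp.dp_model_inspector import DPModelInspector  # assumed import path


class LayerGradSampleTest(unittest.TestCase):  # hypothetical host test case
    def setUp(self):
        # rejects layers the per-sample gradient machinery cannot handle
        self.validator = DPModelInspector()

    def _reset_seeds(self):
        # re-seed so the vanilla and DP runs see identical randomness
        torch.manual_seed(1337)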
Example #3
class PerSampleGradientClipper_test(unittest.TestCase):
    def setUp(self):
        self.DATA_SIZE = 64
        self.criterion = nn.CrossEntropyLoss()

        self.setUp_data()
        self.setUp_original_model()
        self.setUp_clipped_model(clip_value=0.003, run_clipper_step=True)

    def setUp_data(self):
        self.ds = FakeData(
            size=self.DATA_SIZE,
            image_size=(1, 35, 35),
            num_classes=10,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ]),
        )
        self.dl = DataLoader(self.ds, batch_size=self.DATA_SIZE)

    def setUp_original_model(self):
        self.original_model = SampleConvNet()
        for x, y in self.dl:
            logits = self.original_model(x)
            loss = self.criterion(logits, y)
            loss.backward()  # puts grad in self.original_model.parameters()
        self.original_grads_norms = torch.stack(
            [
                p.grad.norm()
                for p in self.original_model.parameters() if p.requires_grad
            ],
            dim=-1,
        )

    def setUp_clipped_model(self, clip_value=0.003, run_clipper_step=True):
        # Deep copy
        self.clipped_model = SampleConvNet()  # create the structure
        self.clipped_model.load_state_dict(
            self.original_model.state_dict())  # fill it

        # The default clip_value is intentionally very small so clipping visibly shrinks the grads
        self.clipper = PerSampleGradientClipper(self.clipped_model, clip_value)
        for x, y in self.dl:
            logits = self.clipped_model(x)
            loss = self.criterion(logits, y)
            loss.backward()  # puts grad in self.clipped_model.parameters()
            if run_clipper_step:
                self.clipper.step()
        self.clipped_grads_norms = torch.stack(
            [
                p.grad.norm()
                for p in self.clipped_model.parameters() if p.requires_grad
            ],
            dim=-1,
        )

    def test_clipped_grad_norm_is_smaller(self):
        """
        Test that grads are clipped and their values change
        """
        for original_layer_norm, clipped_layer_norm in zip(
                self.original_grads_norms, self.clipped_grads_norms):
            self.assertLess(float(clipped_layer_norm),
                            float(original_layer_norm))

    def test_clipped_grad_norms_not_zero(self):
        """
        Test that grads aren't zeroed out by clipping
        """
        allzeros = torch.zeros_like(self.clipped_grads_norms)
        self.assertFalse(torch.allclose(self.clipped_grads_norms, allzeros))

    def test_clipping_to_high_value_does_nothing(self):
        self.setUp_clipped_model(clip_value=9999,
                                 run_clipper_step=True)  # should be a no-op
        self.assertTrue(
            torch.allclose(self.original_grads_norms,
                           self.clipped_grads_norms))

    def test_grad_norms_untouched_without_clip_step(self):
        """
        Test that grads are not clipped until clipper.step() is called
        """
        self.setUp_clipped_model(clip_value=0.003, run_clipper_step=False)
        self.assertTrue(
            torch.allclose(self.original_grads_norms,
                           self.clipped_grads_norms))
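
The test class references a SampleConvNet model and several imports that sit outside the listing. The sketch below is a hypothetical minimal stand-in sized for the (1, 35, 35) FakeData images and 10 classes configured in setUp_data; the import path of PerSampleGradientClipper is likewise an assumption.

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import FakeData
from torchdp import PerSampleGradientClipper  # assumed import path


class SampleConvNet(nn.Module):
    """Hypothetical stand-in: maps (N, 1, 35, 35) images to 10 class logits."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 16, kernel_size=5)  # -> (N, 16, 31, 31)
        self.pool = nn.MaxPool2d(2)                  # -> (N, 16, 15, 15)
        self.fc = nn.Linear(16 * 15 * 15, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv(x)))
        x = x.flatten(start_dim=1)
        return self.fc(x)

Any small network built from layers the per-sample gradient machinery supports (convolutions and linear layers) would serve the same purpose here, since the tests only compare per-parameter gradient norms before and after clipping.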