Example 1
0
    def test_per_sample_api_compute_batch_size_not_pytreeable(self):
        """Batch-size inference must fail cleanly when an input is opaque to pytree flattening."""
        @dataclass
        class NonPytreeableTuple:
            elem1: torch.Tensor
            elem2: torch.Tensor

        class CustomModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(5, 5)

            def forward(self, input1, input2):
                return self.linear(input1.elem1) + self.linear(input1.elem2)

        inp = NonPytreeableTuple(torch.randn(4, 5), torch.randn(4, 5))
        net = CustomModule()

        # The dataclass hides its tensors from pytree flattening, and the
        # second argument carries no tensor either, so no batch size is visible.
        with self.assertRaisesRegex(
            RuntimeError,
            "ExpandedWeights cannot compute the batch size from the inputs",
        ):
            call_for_per_sample_grads(net)(inp, "")

        # Ideally this would report that the input is not pytree-able, but that
        # is hard to detect, so it surfaces as a batch-size mismatch instead.
        with self.assertRaisesRegex(
            RuntimeError,
            "Expected ExpandedWeights to have batch size matching input",
        ):
            call_for_per_sample_grads(net)(inp, torch.randn(5))

        net = CustomModule()  # TODO: functional call bug, sam will fix
        call_for_per_sample_grads(net)(inp, torch.randn(4, 5))
        # An explicit batch_size bypasses inference entirely.
        net = CustomModule()
        call_for_per_sample_grads(net, batch_size=4)(inp, torch.randn(5))
Example 2
0
    def test_per_sample_api_compute_batch_size(self):
        """Batch size inferred from the inputs must be consistent across all of them."""
        class CustomModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(5, 5)

            def forward(self, input1, input2):
                return self.linear(input1) + self.linear(input2)

        mod = CustomModule()
        first = torch.randn(4, 5)
        second = torch.randn(5, 5)

        # Conflicting leading dimensions cannot be resolved to one batch size.
        with self.assertRaisesRegex(
            RuntimeError,
            "found at least one input with batch size 4 and one with batch size 5",
        ):
            call_for_per_sample_grads(mod)(first, second)

        # Matching leading dimensions succeed, positionally and by keyword.
        second = torch.randn(4, 5)
        call_for_per_sample_grads(mod)(first, second)

        mod = CustomModule()
        call_for_per_sample_grads(mod)(first, input2=second)

        mod = CustomModule()
        call_for_per_sample_grads(mod)(input1=first, input2=second)
Example 3
0
    def _do_test(self, module, input):
        """Compare ExpandedWeights per-sample grads against a per-row autograd loop.

        Runs ``module`` once through ``call_for_per_sample_grads`` and once per
        batch element with plain autograd, then checks the summed outputs and
        every per-sample gradient.  When the input dtype is differentiable the
        input's own gradient is compared too.
        """
        batch_size = input.shape[0]
        # Only float/double inputs can require grad; other dtypes are checked
        # for parameter gradients only.
        diff_input = input.dtype == torch.float or input.dtype == torch.double
        if diff_input:
            input.requires_grad_()
        with freeze_rng_state():
            # get per sample grads with ExpandedWeights context manager
            actual_res = call_for_per_sample_grads(module, batch_size, input).sum()
            actual_res.backward()
            actual_grads = []
            for param in module.parameters():
                actual_grads.append(param.grad_sample)
                del param.grad_sample  # leave params clean for later tests
            if diff_input:
                actual_grads.append(input.grad.clone())
                input.grad = torch.zeros_like(input.grad)

            # get per sample grads with a for loop
            expected_res = torch.tensor(0., device=input.device, dtype=torch.double)
            expected_grads = []
            for i in range(batch_size):
                input_slice = input[i]
                diff_params = module.parameters()
                if diff_input:
                    diff_params = chain(diff_params, (input_slice,))
                res = module(input_slice.unsqueeze(0)).sum()
                out_grads = torch.autograd.grad(res, diff_params, torch.ones_like(res), allow_unused=True)
                expected_grads.append(out_grads)
                expected_res += res
            expected_grads = tuple(torch.stack(grad) for grad in zip(*expected_grads))
        self.assertEqual(actual_res, expected_res)
        # Explicit loop instead of a list comprehension evaluated only for its
        # assertEqual side effects; also guard against zip silently truncating.
        self.assertEqual(len(actual_grads), len(expected_grads))
        for actual, expected in zip(actual_grads, expected_grads):
            self.assertEqual(actual, expected)
Example 4
0
    def _test_model(self, model, batch_size, input, device, loss_reduction="sum"):
        """Compare per-sample grads from the API against a per-example autograd loop."""
        model = model(10).to(device)
        targets = torch.randint(0, 10, (batch_size,), device=device)
        criterion = CrossEntropyLoss(reduction=loss_reduction)

        # Per-sample gradients via the ExpandedWeights API.
        output = call_for_per_sample_grads(
            model, loss_reduction=loss_reduction)(input)
        criterion(output, targets).backward()
        result = []
        for weight in model.parameters():
            result.append(weight.grad_sample)
            del weight.grad_sample  # leave params clean for later tests

        # Reference: one autograd call per batch element.
        expected = []
        for i in range(batch_size):
            sample_loss = criterion(model(input[i].unsqueeze(0)),
                                    targets[i].unsqueeze(0))
            expected.append(
                torch.autograd.grad(sample_loss, model.parameters(),
                                    torch.ones_like(sample_loss)))

        stacked = [torch.stack(grad) for grad in zip(*expected)]
        for res, exp in zip(result, stacked):
            self.assertEqual(res, exp, atol=1e-4, rtol=5e-5)
Example 5
0
    def _do_test_multi_input(self, module, input):
        """Per-sample grads must accumulate across two uses of the same module.

        Wraps ``module`` so each forward applies it twice; the resulting
        per-sample gradients must equal twice the single-use reference
        computed with a per-element autograd loop.
        """
        class TestModule(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, input):
                return self.module(input) + self.module(input)

        batch_size = input.shape[0]
        with freeze_rng_state():
            # get per sample grads with ExpandedWeights context manager, calling .backward() twice
            test_module = TestModule(module)
            actual_res = call_for_per_sample_grads(test_module, batch_size,
                                                   input).sum()
            actual_res.backward()
            actual_grads = []
            for param in module.parameters():
                actual_grads.append(param.grad_sample)
                del param.grad_sample  # leave params clean for later tests

            # get per sample grads with a for loop, running over the input twice
            expected_grads = []
            for i in range(batch_size):
                res = module(input[i].unsqueeze(0)).sum()
                expected_grads.append(
                    torch.autograd.grad(res, module.parameters(),
                                        torch.ones_like(res)))
            expected_grads = tuple(
                torch.stack(grad) for grad in zip(*expected_grads))
        # BUG FIX: the original ``assert [torch.allclose(...) for ...]`` asserted
        # a non-empty list, which is always truthy, so the comparison could never
        # fail.  Assert each pair individually instead.
        for actual, expected in zip(actual_grads, expected_grads):
            assert torch.allclose(actual, 2 * expected)
Example 6
0
    def _do_test(self, module, input):
        """Compare ExpandedWeights per-sample grads against a per-row autograd loop."""
        batch_size = input.shape[0]
        with freeze_rng_state():
            # get per sample grads with ExpandedWeights context manager
            actual_res = call_for_per_sample_grads(module, batch_size,
                                                   input).sum()
            actual_res.backward()
            actual_grads = []
            for param in module.parameters():
                actual_grads.append(param.grad_sample)
                del param.grad_sample  # leave params clean for later tests

            # get per sample grads with a for loop
            expected_res = torch.tensor(0.)
            expected_grads = []
            for i in range(batch_size):
                res = module(input[i].unsqueeze(0)).sum()
                expected_grads.append(
                    torch.autograd.grad(res, module.parameters(),
                                        torch.ones_like(res)))
                expected_res += res
            expected_grads = tuple(
                torch.stack(grad) for grad in zip(*expected_grads))
        self.assertEqual(actual_res, expected_res)
        # BUG FIX: the original ``assert [torch.allclose(...) for ...]`` asserted
        # a non-empty list, which is always truthy, so the gradient comparison
        # could never fail.  Assert each pair individually instead.
        for actual, expected in zip(actual_grads, expected_grads):
            assert torch.allclose(actual, expected)
Example 7
0
 def test_per_sample_api_failing(self):
     """Invalid arguments to call_for_per_sample_grads raise descriptive errors."""
     module = nn.Linear(10, 10)
     input = torch.randn(64, 10)
     # Non-module, non-integer batch size, and negative batch size all fail.
     with self.assertRaisesRegex(RuntimeError, r"Module passed must be nn.Module"):
         call_for_per_sample_grads("fail", 64, input)
     with self.assertRaisesRegex(RuntimeError, r"Batch size passed must be an integer"):
         call_for_per_sample_grads(module, 6.4, input)
     with self.assertRaisesRegex(RuntimeError, r"Batch size must be positive"):
         call_for_per_sample_grads(module, -64, input)
     # A second call while grad_sample fields are still populated must fail.
     with self.assertRaisesRegex(RuntimeError, r"incorrect for multiple calls"):
         call_for_per_sample_grads(module, 64, input).sum().backward()  # populate grad_sample fields
         call_for_per_sample_grads(module, 64, input)
Example 8
0
    def _do_test_multi_input(self, module, input):
        """Per-sample grads must accumulate across two uses of the same module.

        Wraps ``module`` so each forward applies it twice; the resulting
        per-sample gradients must equal twice the single-use reference from a
        per-element autograd loop.  When the input dtype is differentiable the
        input's own gradient is included in the comparison.
        """
        class TestModule(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, input):
                return self.module(input) + self.module(input)

        batch_size = input.shape[0]
        # Only float/double inputs can require grad.
        diff_input = input.dtype == torch.float or input.dtype == torch.double
        if diff_input:
            input.requires_grad_()
        with freeze_rng_state():
            # get per sample grads with ExpandedWeights context manager, calling .backward() twice
            test_module = TestModule(module)
            actual_res = call_for_per_sample_grads(
                test_module, loss_reduction="sum")(input).sum()
            actual_res.backward()
            actual_grads = []
            for param in module.parameters():
                actual_grads.append(param.grad_sample)
                del param.grad_sample  # leave params clean for later tests
            if diff_input:
                actual_grads.append(input.grad.clone())
                input.grad = torch.zeros_like(input.grad)

            # get per sample grads with a for loop, running over the input twice
            expected_grads = []
            for i in range(batch_size):
                input_slice = input[i]
                diff_params = module.parameters()
                if diff_input:
                    diff_params = chain(diff_params, (input_slice, ))
                res = module(input_slice.unsqueeze(0)).sum()
                out_grads = torch.autograd.grad(res,
                                                diff_params,
                                                torch.ones_like(res),
                                                allow_unused=True)
                expected_grads.append(out_grads)
        # BUG FIX: drop unused-parameter gradients (allow_unused=True yields
        # None for them) *before* stacking -- torch.stack on a tuple containing
        # None raises.  The original filtered after stacking, where a None can
        # no longer occur, making the filter dead code.
        expected_grads = tuple(
            torch.stack(grad) for grad in zip(*expected_grads)
            if grad[0] is not None)
        # BUG FIX: the original wrapped self.assertEqual calls in a list and
        # asserted the list, which is truthy whenever non-empty, so the
        # comparison could never fail.  Assert each pair explicitly instead.
        for actual, expected in zip(actual_grads, expected_grads):
            self.assertEqual(actual, 2 * expected)
    def test_small_model(self, device):
        """End-to-end per-sample grads on a small conv net match a per-example loop."""
        def convnet(num_classes):
            layers = [
                nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=2, stride=2),
                nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=2, stride=2),
                nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=2, stride=2),
                nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d((1, 1)),
                nn.Flatten(start_dim=1, end_dim=-1),
                nn.Linear(128, num_classes, bias=True),
            ]
            return nn.Sequential(*layers)

        batch_size = 32
        model = convnet(10).to(device)
        input = torch.randn([batch_size, 3, 28, 28], device=device)
        targets = torch.randint(0, 10, (batch_size, ), device=device)
        # 'sum' keeps per-example losses additive so the for-loop reference below is valid.
        criterion = CrossEntropyLoss(reduction='sum')

        # Per-sample gradients via the ExpandedWeights API.
        output = call_for_per_sample_grads(model, batch_size, input)
        criterion(output, targets).backward()
        result = []
        for weight in model.parameters():
            result.append(weight.grad_sample)
            del weight.grad_sample  # leave params clean for later tests

        # Reference: one autograd call per batch element.
        expected = []
        for i in range(batch_size):
            sample_loss = criterion(model(input[i].unsqueeze(0)),
                                    targets[i].unsqueeze(0))
            expected.append(
                torch.autograd.grad(sample_loss, model.parameters(),
                                    torch.ones_like(sample_loss)))

        stacked = [torch.stack(grad) for grad in zip(*expected)]
        for res, exp in zip(result, stacked):
            self.assertEqual(res, exp, atol=1e-4, rtol=5e-5)
Example 10
0
    def test_per_sample_api_failing(self):
        """Invalid arguments to call_for_per_sample_grads raise descriptive errors."""
        module = nn.Linear(10, 10)
        input = torch.randn(64, 10)

        # Non-module, non-integer batch size, and negative batch size all fail.
        with self.assertRaisesRegex(RuntimeError,
                                    r"Module passed must be nn.Module"):
            call_for_per_sample_grads("fail")(input)
        with self.assertRaisesRegex(
                RuntimeError, r"Batch size passed must be None or an integer"):
            call_for_per_sample_grads(module, batch_size=6.4)(input)
        with self.assertRaisesRegex(RuntimeError,
                                    r"Batch size must be positive"):
            call_for_per_sample_grads(module, batch_size=-64)(input)

        # A second call while grad_sample fields are still populated must fail.
        with self.assertRaisesRegex(RuntimeError,
                                    r"incorrect for multiple calls"):
            call_for_per_sample_grads(module)(input).sum().backward()  # populate grad_sample fields
            call_for_per_sample_grads(module)(input)

        # Fresh module without grad_sample fields: bad loss_reduction is rejected.
        module = nn.Linear(10, 10)
        with self.assertRaisesRegex(
                RuntimeError,
                r"Expected loss_reduction argument to be sum or mean"):
            call_for_per_sample_grads(module, loss_reduction="")(input)