def test_per_sample_api_compute_batch_size_not_pytreeable(self):
    @dataclass
    class NonPytreeableTuple:
        elem1: torch.Tensor
        elem2: torch.Tensor

    class CustomModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(5, 5)

        def forward(self, input1, input2):
            return self.linear(input1.elem1) + self.linear(input1.elem2)

    input = NonPytreeableTuple(torch.randn(4, 5), torch.randn(4, 5))
    model = CustomModule()
    with self.assertRaisesRegex(
        RuntimeError, "ExpandedWeights cannot compute the batch size from the inputs"
    ):
        call_for_per_sample_grads(model)(input, "")

    # would prefer for it to error because input is not pytree-able, but that's hard to detect
    with self.assertRaisesRegex(
        RuntimeError, "Expected ExpandedWeights to have batch size matching input"
    ):
        call_for_per_sample_grads(model)(input, torch.randn(5))

    model = CustomModule()  # TODO: functional call bug, sam will fix
    call_for_per_sample_grads(model)(input, torch.randn(4, 5))
    model = CustomModule()
    call_for_per_sample_grads(model, batch_size=4)(input, torch.randn(5))

def test_per_sample_api_compute_batch_size(self):
    class CustomModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(5, 5)

        def forward(self, input1, input2):
            return self.linear(input1) + self.linear(input2)

    module = CustomModule()
    input1 = torch.randn(4, 5)
    input2 = torch.randn(5, 5)
    with self.assertRaisesRegex(
        RuntimeError, "found at least one input with batch size 4 and one with batch size 5"
    ):
        call_for_per_sample_grads(module)(input1, input2)

    input2 = torch.randn(4, 5)
    call_for_per_sample_grads(module)(input1, input2)

    module = CustomModule()
    call_for_per_sample_grads(module)(input1, input2=input2)

    module = CustomModule()
    call_for_per_sample_grads(module)(input1=input1, input2=input2)

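# Illustrative sketch (not from the test suite, and not the actual ExpandedWeights
# implementation): batch size inference can be thought of as flattening all inputs
# with pytree and requiring every tensor leaf to share one leading dimension.
# `infer_batch_size` is a hypothetical helper, written only to make the expected
# errors in the two tests above concrete.
#
#     from torch.utils._pytree import tree_flatten
#
#     def infer_batch_size(*args, **kwargs):
#         leaves, _ = tree_flatten((args, kwargs))
#         sizes = {leaf.shape[0] for leaf in leaves if isinstance(leaf, torch.Tensor)}
#         if len(sizes) > 1:
#             raise RuntimeError(f"found inputs with batch sizes {sorted(sizes)}")
#         return sizes.pop() if sizes else None
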
def _do_test(self, module, input):
    batch_size = input.shape[0]
    diff_input = input.dtype == torch.float or input.dtype == torch.double
    if diff_input:
        input.requires_grad_()

    with freeze_rng_state():
        # get per sample grads with ExpandedWeights context manager
        actual_res = call_for_per_sample_grads(module, batch_size, input).sum()
        actual_res.backward()
        actual_grads = []
        for param in module.parameters():
            actual_grads.append(param.grad_sample)
            del param.grad_sample
        if diff_input:
            actual_grads.append(input.grad.clone())
            input.grad = torch.zeros_like(input.grad)

        # get per sample grads with a for loop
        expected_res = torch.tensor(0., device=input.device, dtype=torch.double)
        expected_grads = []
        for i in range(batch_size):
            input_slice = input[i]
            diff_params = module.parameters()
            if diff_input:
                diff_params = chain(diff_params, (input_slice,))
            res = module(input_slice.unsqueeze(0)).sum()
            out_grads = torch.autograd.grad(res, diff_params, torch.ones_like(res), allow_unused=True)
            expected_grads.append(out_grads)
            expected_res += res
        expected_grads = tuple(torch.stack(grad) for grad in zip(*expected_grads))

    self.assertEqual(actual_res, expected_res)
    for actual, expected in zip(actual_grads, expected_grads):
        self.assertEqual(actual, expected)

def _test_model(self, model, batch_size, input, device, loss_reduction="sum"):
    model = model(10).to(device)
    targets = torch.randint(0, 10, (batch_size,), device=device)
    criterion = CrossEntropyLoss(reduction=loss_reduction)
    result = call_for_per_sample_grads(model, loss_reduction=loss_reduction)(input)
    loss = criterion(result, targets)
    loss.backward()
    result = []
    for weight in model.parameters():
        result.append(weight.grad_sample)
        del weight.grad_sample

    expected = []
    for i in range(batch_size):
        loss = criterion(model(input[i].unsqueeze(0)), targets[i].unsqueeze(0))
        expected.append(torch.autograd.grad(loss, model.parameters(), torch.ones_like(loss)))
    expected = [torch.stack(grad) for grad in zip(*expected)]

    for res, exp in zip(result, expected):
        self.assertEqual(res, exp, atol=1e-4, rtol=5e-5)

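# Note on loss_reduction (illustrative, not itself a test): with reduction="mean",
# the batch loss scales each sample's gradient contribution by 1 / batch_size, so
# per-sample grads would come out batch_size times smaller than the batch-of-one
# reference losses computed in the loop above. Passing a matching
# loss_reduction="mean" tells call_for_per_sample_grads to compensate for that
# scaling. A minimal usage sketch, assuming the same call_for_per_sample_grads
# import as these tests:
#
#     model = nn.Linear(5, 3)
#     x = torch.randn(8, 5)
#     targets = torch.randint(0, 3, (8,))
#     criterion = nn.CrossEntropyLoss(reduction="mean")
#     out = call_for_per_sample_grads(model, loss_reduction="mean")(x)
#     criterion(out, targets).backward()
#     # model.weight.grad_sample[i] now matches the grad of the loss on sample i alone
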
def _do_test_multi_input(self, module, input):
    class TestModule(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, input):
            return self.module(input) + self.module(input)

    batch_size = input.shape[0]
    with freeze_rng_state():
        # get per sample grads with ExpandedWeights context manager, calling .backward() twice
        test_module = TestModule(module)
        actual_res = call_for_per_sample_grads(test_module, batch_size, input).sum()
        actual_res.backward()
        actual_grads = []
        for param in module.parameters():
            actual_grads.append(param.grad_sample)
            del param.grad_sample

        # get per sample grads with a for loop, running over the input twice
        expected_grads = []
        for i in range(batch_size):
            res = module(input[i].unsqueeze(0)).sum()
            expected_grads.append(torch.autograd.grad(res, module.parameters(), torch.ones_like(res)))
        expected_grads = tuple(torch.stack(grad) for grad in zip(*expected_grads))

    # the module runs twice per sample, so each per sample grad is doubled.
    # NB: the original `assert [torch.allclose(...) for ...]` was always truthy
    # (a non-empty list), so it never actually checked anything; assert each pair.
    for actual, expected in zip(actual_grads, expected_grads):
        self.assertTrue(torch.allclose(actual, 2 * expected))

def _do_test(self, module, input):
    batch_size = input.shape[0]
    with freeze_rng_state():
        # get per sample grads with ExpandedWeights context manager
        actual_res = call_for_per_sample_grads(module, batch_size, input).sum()
        actual_res.backward()
        actual_grads = []
        for param in module.parameters():
            actual_grads.append(param.grad_sample)
            del param.grad_sample

        # get per sample grads with a for loop
        expected_res = torch.tensor(0.)
        expected_grads = []
        for i in range(batch_size):
            res = module(input[i].unsqueeze(0)).sum()
            expected_grads.append(torch.autograd.grad(res, module.parameters(), torch.ones_like(res)))
            expected_res += res
        expected_grads = tuple(torch.stack(grad) for grad in zip(*expected_grads))

    self.assertEqual(actual_res, expected_res)
    # NB: the original `assert [torch.allclose(...) for ...]` was always truthy
    # (a non-empty list), so it never actually checked anything; assert each pair.
    for actual, expected in zip(actual_grads, expected_grads):
        self.assertTrue(torch.allclose(actual, expected))

def test_per_sample_api_failing(self):
    module = nn.Linear(10, 10)
    input = torch.randn(64, 10)
    with self.assertRaisesRegex(RuntimeError, r"Module passed must be nn.Module"):
        call_for_per_sample_grads("fail", 64, input)
    with self.assertRaisesRegex(RuntimeError, r"Batch size passed must be an integer"):
        call_for_per_sample_grads(module, 6.4, input)
    with self.assertRaisesRegex(RuntimeError, r"Batch size must be positive"):
        call_for_per_sample_grads(module, -64, input)
    with self.assertRaisesRegex(RuntimeError, r"incorrect for multiple calls"):
        loss = call_for_per_sample_grads(module, 64, input).sum()
        loss.backward()  # populate grad_sample fields
        call_for_per_sample_grads(module, 64, input)

def _do_test_multi_input(self, module, input):
    class TestModule(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, input):
            return self.module(input) + self.module(input)

    batch_size = input.shape[0]
    diff_input = input.dtype == torch.float or input.dtype == torch.double
    if diff_input:
        input.requires_grad_()

    with freeze_rng_state():
        # get per sample grads with ExpandedWeights context manager, calling .backward() twice
        test_module = TestModule(module)
        actual_res = call_for_per_sample_grads(test_module, loss_reduction="sum")(input).sum()
        actual_res.backward()
        actual_grads = []
        for param in module.parameters():
            actual_grads.append(param.grad_sample)
            del param.grad_sample
        if diff_input:
            actual_grads.append(input.grad.clone())
            input.grad = torch.zeros_like(input.grad)

        # get per sample grads with a for loop, running over the input twice
        expected_grads = []
        for i in range(batch_size):
            input_slice = input[i]
            diff_params = module.parameters()
            if diff_input:
                diff_params = chain(diff_params, (input_slice,))
            res = module(input_slice.unsqueeze(0)).sum()
            out_grads = torch.autograd.grad(res, diff_params, torch.ones_like(res), allow_unused=True)
            expected_grads.append(out_grads)
        expected_grads = tuple(torch.stack(grad) for grad in zip(*expected_grads))
        expected_grads = tuple(expected_grad for expected_grad in expected_grads if expected_grad is not None)

    # the module runs twice per sample, so each per sample grad is doubled
    for actual, expected in zip(actual_grads, expected_grads):
        self.assertEqual(actual, 2 * expected)

def test_small_model(self, device):
    def convnet(num_classes):
        return nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(start_dim=1, end_dim=-1),
            nn.Linear(128, num_classes, bias=True),
        )

    batch_size = 32
    model = convnet(10).to(device)
    input = torch.randn([batch_size, 3, 28, 28], device=device)
    targets = torch.randint(0, 10, (batch_size,), device=device)
    # use a loss that doesn't average across the batch so it can be checked in a for loop
    criterion = CrossEntropyLoss(reduction='sum')
    result = call_for_per_sample_grads(model, batch_size, input)
    loss = criterion(result, targets)
    loss.backward()
    result = []
    for weight in model.parameters():
        result.append(weight.grad_sample)
        del weight.grad_sample

    expected = []
    for i in range(batch_size):
        loss = criterion(model(input[i].unsqueeze(0)), targets[i].unsqueeze(0))
        expected.append(torch.autograd.grad(loss, model.parameters(), torch.ones_like(loss)))
    expected = [torch.stack(grad) for grad in zip(*expected)]

    for res, exp in zip(result, expected):
        self.assertEqual(res, exp, atol=1e-4, rtol=5e-5)

def test_per_sample_api_failing(self):
    module = nn.Linear(10, 10)
    input = torch.randn(64, 10)
    with self.assertRaisesRegex(RuntimeError, r"Module passed must be nn.Module"):
        call_for_per_sample_grads("fail")(input)
    with self.assertRaisesRegex(RuntimeError, r"Batch size passed must be None or an integer"):
        call_for_per_sample_grads(module, batch_size=6.4)(input)
    with self.assertRaisesRegex(RuntimeError, r"Batch size must be positive"):
        call_for_per_sample_grads(module, batch_size=-64)(input)
    with self.assertRaisesRegex(RuntimeError, r"incorrect for multiple calls"):
        loss = call_for_per_sample_grads(module)(input).sum()
        loss.backward()  # populate grad_sample fields
        call_for_per_sample_grads(module)(input)

    module = nn.Linear(10, 10)  # reset to not have grad_sample fields
    with self.assertRaisesRegex(RuntimeError, r"Expected loss_reduction argument to be sum or mean"):
        call_for_per_sample_grads(module, loss_reduction="")(input)
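
# Minimal usage sketch for reference (assumes the same call_for_per_sample_grads
# import as these tests; not itself a test). It also shows why the test above
# expects an "incorrect for multiple calls" error: grad_sample must be cleared
# from each parameter before the wrapped module is called again.
#
#     model = nn.Linear(10, 10)
#     x = torch.randn(64, 10)
#     call_for_per_sample_grads(model)(x).sum().backward()
#     for p in model.parameters():
#         print(p.grad_sample.shape)  # leading dim is the batch size, 64
#         del p.grad_sample           # must be cleared between calls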