def test_forward_helper_failure_args(self, device):
    """Check that ``forward_helper`` rejects malformed arguments.

    Each failure mode is matched by its error message:
      * the first (input) argument is itself an ExpandedWeight,
      * the first argument is not a Tensor,
      * the first argument is 0-dim, so it has no batch dimension,
      * the first argument has a batch dimension of size 0,
      * an ExpandedWeight weight and/or bias whose batch size (4) differs
        from the input's batch size (3).
    """
    weight = torch.randn(5, 4, device=device)
    bias = torch.randn(5, device=device)

    # loss_reduction is passed explicitly, matching how test_forward_helper
    # builds ExpandedWeight; if the constructor requires it, omitting it would
    # raise a TypeError instead of the RuntimeError being asserted.
    with self.assertRaisesRegex(RuntimeError, r"do not support inputs that are also ExpandedWeights."):
        expanded_input = ExpandedWeight(
            torch.randn(3, 4, device=device, requires_grad=True), 3, loss_reduction="sum")
        expanded_args, expanded_kwargs = standard_kwargs(('bias',), (expanded_input, weight, bias))
        forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)

    with self.assertRaisesRegex(RuntimeError, r"requires a Tensor as the first input"):
        expanded_args, expanded_kwargs = standard_kwargs(('bias',), (3, weight, bias))
        forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)

    with self.assertRaisesRegex(RuntimeError, r"requires a batch dimension but got an input of size 0"):
        scalar_input = torch.tensor(3, device=device)
        expanded_args, expanded_kwargs = standard_kwargs(('bias',), (scalar_input, weight, bias))
        forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)

    with self.assertRaisesRegex(RuntimeError, r"0 is not a valid batch size for Expanded Weights"):
        empty_batch_input = torch.randn(0, 1, 2, device=device)
        expanded_args, expanded_kwargs = standard_kwargs(('bias',), (empty_batch_input, weight, bias))
        forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)

    # The input has batch size 3; every ExpandedWeight below claims batch size 4.
    inp = torch.randn(3, 4, device=device)
    for weight_batched, bias_batched in product([True, False], [True, False]):
        if not (weight_batched or bias_batched):
            # Nothing is expanded, so there is no batch size to mismatch.
            continue
        maybe_batched_weight = (
            ExpandedWeight(weight.clone().requires_grad_(), 4, loss_reduction="sum")
            if weight_batched else weight)
        maybe_batched_bias = (
            ExpandedWeight(bias.clone().requires_grad_(), 4, loss_reduction="sum")
            if bias_batched else bias)
        with self.assertRaisesRegex(RuntimeError, r"Expected ExpandedWeights to have batch size matching input"):
            expanded_args, expanded_kwargs = standard_kwargs(
                ('bias',), (inp, maybe_batched_weight, maybe_batched_bias))
            forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
def test_forward_helper(self, device):
    """Check ``forward_helper`` on ``nn.functional.linear`` for every mix of
    plain-tensor and ExpandedWeight weight/bias.

    For each combination, the result must equal a plain ``linear`` call on the
    original tensors. ``standard_kwargs`` must split the arguments into two
    positionals (input, weight) plus a single ``bias`` kwarg, and must pass the
    original argument objects through unchanged (checked by identity).
    """
    inp = torch.randn(3, 4, device=device)
    weight = torch.randn(5, 4, device=device)
    bias = torch.randn(5, device=device)
    # The reference output does not depend on the loop variables; compute once.
    expected = nn.functional.linear(inp, weight, bias)

    for weight_batched, bias_batched in product([True, False], [True, False]):
        maybe_batched_weight = (
            ExpandedWeight(weight.clone().requires_grad_(), 3, loss_reduction="sum")
            if weight_batched else weight)
        maybe_batched_bias = (
            ExpandedWeight(bias.clone().requires_grad_(), 3, loss_reduction="sum")
            if bias_batched else bias)

        args = (inp, maybe_batched_weight, maybe_batched_bias)
        expanded_args, expanded_kwargs = standard_kwargs(('bias',), args)
        res = forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
        self.assertEqual(res, expected)

        # assertIs compares identity only, so it avoids the tensor property
        # checks assertEqual would run. Unlike a bare `assert`, it is not
        # stripped when Python runs with -O.
        self.assertEqual(len(expanded_args), 2)
        self.assertIs(expanded_args[0], args[0])
        self.assertIs(expanded_args[1], args[1])
        self.assertEqual(len(expanded_kwargs), 1)
        self.assertIs(expanded_kwargs['bias'], args[2])