def test_forward_helper_failure_args(self, device):
    weight = torch.randn(5, 4, device=device)
    bias = torch.randn(5, device=device)
    with self.assertRaisesRegex(RuntimeError, r"do not support inputs that are also ExpandedWeights."):
        input = ExpandedWeight(torch.randn(3, 4, requires_grad=True), 3, loss_reduction="sum")
        expanded_args, expanded_kwargs = standard_kwargs(('bias',), (input, weight, bias))
        forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
    with self.assertRaisesRegex(RuntimeError, r"requires a Tensor as the first input"):
        expanded_args, expanded_kwargs = standard_kwargs(('bias',), (3, weight, bias))
        forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
    with self.assertRaisesRegex(RuntimeError, r"requires a batch dimension but got an input of size 0"):
        expanded_args, expanded_kwargs = standard_kwargs(('bias',), (torch.tensor(3), weight, bias))
        forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
    with self.assertRaisesRegex(RuntimeError, r"0 is not a valid batch size for Expanded Weights"):
        expanded_args, expanded_kwargs = standard_kwargs(('bias',), (torch.randn(0, 1, 2), weight, bias))
        forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)

    input = torch.randn(3, 4)
    for (weight_batched, bias_batched) in product([True, False], [True, False]):
        if not weight_batched and not bias_batched:
            continue

        # batch size 4 deliberately mismatches the input's batch dimension of 3
        maybe_batched_weight = ExpandedWeight(weight.clone().requires_grad_(), 4, loss_reduction="sum") if weight_batched else weight
        maybe_batched_bias = ExpandedWeight(bias.clone().requires_grad_(), 4, loss_reduction="sum") if bias_batched else bias
        with self.assertRaisesRegex(RuntimeError, r"Expected ExpandedWeights to have batch size matching input"):
            expanded_args, expanded_kwargs = standard_kwargs(('bias',), (input, maybe_batched_weight, maybe_batched_bias))
            forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
def test_forward_helper(self, device):
    input = torch.randn(3, 4, device=device)
    weight = torch.randn(5, 4, device=device)
    bias = torch.randn(5, device=device)
    for (weight_batched, bias_batched) in product([True, False], [True, False]):
        maybe_batched_weight = weight
        maybe_batched_bias = bias
        if weight_batched:
            maybe_batched_weight = ExpandedWeight(weight.clone().requires_grad_(), 3, loss_reduction="sum")
        if bias_batched:
            maybe_batched_bias = ExpandedWeight(bias.clone().requires_grad_(), 3, loss_reduction="sum")
        args = (input, maybe_batched_weight, maybe_batched_bias)
        expanded_args, expanded_kwargs = standard_kwargs(('bias',), args)
        res = forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
        expected = nn.functional.linear(input, weight, bias)
        self.assertEqual(res, expected)

        self.assertEqual(len(expanded_args), 2)
        assert expanded_args[0] is args[0]  # avoids property checks in assertEquals
        assert expanded_args[1] is args[1]  # avoids property checks in assertEquals
        self.assertEqual(len(expanded_kwargs), 1)
        assert expanded_kwargs['bias'] is args[2]  # avoids property checks in assertEquals
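# For reference, the contract of standard_kwargs as exercised above (inferred
# from the assertions, so treat the exact behavior as an assumption rather than
# documented API): given the names of the trailing keyword arguments, it splits
# a flat argument tuple into positional args and a kwargs dict, e.g.
#
#     standard_kwargs(('bias',), (input, weight, bias))
#     # ==> ((input, weight), {'bias': bias})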
def test_unpack_expanded_weight_or_tensor(self, device):
    input = torch.randn(3, requires_grad=True, device=device)
    self.assertEqual(input, unpack_expanded_weight_or_tensor(ExpandedWeight(input, 3, loss_reduction="sum")))

    input.requires_grad_(False)
    self.assertEqual(input, unpack_expanded_weight_or_tensor(input))
    self.assertIsNone(unpack_expanded_weight_or_tensor(4))
def test_unpack_expanded_weight_or_tensor_with_custom_function(self, device):
    input = torch.randn(3, requires_grad=True, device=device)
    self.assertTrue(unpack_expanded_weight_or_tensor(ExpandedWeight(input, 3, loss_reduction="sum"), lambda x: x is input))

    input.requires_grad_(False)
    self.assertTrue(unpack_expanded_weight_or_tensor(input, lambda x: x is input))
    self.assertIsNone(unpack_expanded_weight_or_tensor(4, lambda x: x is input))
def test_expanded_weight_error(self, device):
    batch_size = 3
    sample_input = make_tensor((batch_size, 4), dtype=torch.float32, device=device, requires_grad=True)
    sample_weight = make_tensor((4,), dtype=torch.float32, device=device, requires_grad=True)
    with self.assertRaisesRegex(RuntimeError, r"Expanded Weights encountered but cannot handle function"):
        torch.add(sample_input, ExpandedWeight(sample_weight, batch_size, loss_reduction="sum"))
def test_set_grad_sample_if_exists(self, device):
    def test_fn(_):
        return True

    # ExpandedWeights get the test_fn result stored on their original tensor
    orig_weight = torch.randn(4, device=device, requires_grad=True)
    expanded_weight = ExpandedWeight(orig_weight, 3, loss_reduction="sum")
    set_grad_sample_if_exists(expanded_weight, test_fn)
    self.assertTrue(hasattr(orig_weight, 'grad_sample'))
    self.assertTrue(orig_weight.grad_sample)

    # plain tensors and non-tensors are left untouched
    basic_tensor = torch.randn(4, device=device)
    set_grad_sample_if_exists(basic_tensor, test_fn)
    self.assertFalse(hasattr(basic_tensor, 'grad_sample'))

    non_tensor = 3
    set_grad_sample_if_exists(non_tensor, test_fn)
    self.assertFalse(hasattr(non_tensor, 'grad_sample'))
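# A minimal end-to-end sketch (illustrative only, not one of the tests above) of
# the flow these helpers support: wrap the original parameters in ExpandedWeight,
# run a supported functional op, and after backward each original tensor carries
# a .grad_sample attribute with one gradient per batch element.
#
#     batch_size = 3
#     x = torch.randn(batch_size, 4)
#     w = torch.randn(5, 4, requires_grad=True)
#     b = torch.randn(5, requires_grad=True)
#     out = nn.functional.linear(
#         x,
#         ExpandedWeight(w, batch_size, loss_reduction="sum"),
#         ExpandedWeight(b, batch_size, loss_reduction="sum"),
#     )
#     out.sum().backward()
#     assert w.grad_sample.shape == (batch_size, 5, 4)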
def expanded_weight_or_clone(arg):
    # Wrap differentiable tensors as ExpandedWeights; clone everything else
    # so the original inputs are left untouched.
    if is_diff_tensor(arg):
        return ExpandedWeight(torch.clone(arg), batch_size, loss_reduction)
    return clone_if_tensor(arg)
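# A hedged usage sketch (the call site is an assumption): this helper is
# presumably mapped over a test case's inputs so every differentiable tensor
# becomes an ExpandedWeight while everything else is cloned, with batch_size
# and loss_reduction closed over from the enclosing scope, e.g.:
#
#     ew_args = tuple(expanded_weight_or_clone(arg) for arg in args)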