Example #1
0
 def test_forward_helper_failure_args(self, device):
     """Check that forward_helper rejects malformed inputs with clear errors.

     Each `assertRaisesRegex` block feeds one specific bad configuration
     through standard_kwargs + forward_helper and pins the error message.
     """
     weight = torch.randn(5, 4, device=device)
     bias = torch.randn(5, device=device)
     # An ExpandedWeight may not itself be passed as the *input* tensor.
     with self.assertRaisesRegex(RuntimeError, r"do not support inputs that are also ExpandedWeights."):
         # loss_reduction="sum" matches how ExpandedWeight is constructed elsewhere in this file.
         input = ExpandedWeight(torch.randn(3, 4, requires_grad=True, device=device), 3, loss_reduction="sum")
         expanded_args, expanded_kwargs = standard_kwargs(('bias',), (input, weight, bias))
         forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
     # The first positional argument must be a Tensor, not a plain Python value.
     with self.assertRaisesRegex(RuntimeError, r"requires a Tensor as the first input"):
         expanded_args, expanded_kwargs = standard_kwargs(('bias',), (3, weight, bias))
         forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
     # A 0-dim tensor has no batch dimension at all.
     with self.assertRaisesRegex(RuntimeError, r"requires a batch dimension but got an input of size 0"):
         expanded_args, expanded_kwargs = standard_kwargs(('bias',), (torch.tensor(3), weight, bias))
         forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
     # A batch dimension of size 0 is rejected.
     with self.assertRaisesRegex(RuntimeError, r"0 is not a valid batch size for Expanded Weights"):
         expanded_args, expanded_kwargs = standard_kwargs(('bias',), (torch.randn(0, 1, 2), weight, bias))
         forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
     input = torch.randn(3, 4, device=device)
     for (weight_batched, bias_batched) in product([True, False], [True, False]):
         if not weight_batched and not bias_batched:
             continue
         # Batch size 4 deliberately mismatches the input's batch size of 3.
         maybe_batched_weight = ExpandedWeight(weight.clone().requires_grad_(), 4, loss_reduction="sum") if weight_batched else weight
         maybe_batched_bias = ExpandedWeight(bias.clone().requires_grad_(), 4, loss_reduction="sum") if bias_batched else bias
         with self.assertRaisesRegex(RuntimeError, r"Expected ExpandedWeights to have batch size matching input"):
             expanded_args, expanded_kwargs = standard_kwargs(('bias',), (input, maybe_batched_weight, maybe_batched_bias))
             forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
Example #2
0
    def test_forward_helper(self, device):
        """forward_helper(linear, ...) must match plain nn.functional.linear
        and must pass the expanded args/kwargs through without copying them.
        """
        input = torch.randn(3, 4, device=device)
        weight = torch.randn(5, 4, device=device)
        bias = torch.randn(5, device=device)
        # Try every combination of wrapping weight/bias in ExpandedWeight.
        for weight_batched, bias_batched in product([True, False], repeat=2):
            maybe_batched_weight = (
                ExpandedWeight(weight.clone().requires_grad_(), 3, loss_reduction="sum")
                if weight_batched else weight)
            maybe_batched_bias = (
                ExpandedWeight(bias.clone().requires_grad_(), 3, loss_reduction="sum")
                if bias_batched else bias)
            args = (input, maybe_batched_weight, maybe_batched_bias)
            expanded_args, expanded_kwargs = standard_kwargs(('bias', ), args)
            res = forward_helper(nn.functional.linear, expanded_args,
                                 expanded_kwargs)
            # The wrapped forward must agree with the unwrapped reference call.
            expected = nn.functional.linear(input, weight, bias)
            self.assertEqual(res, expected)

            # standard_kwargs must hand back the very same objects it was given.
            self.assertEqual(len(expanded_args), 2)
            assert expanded_args[0] is args[0]  # avoids property checks in assertEquals
            assert expanded_args[1] is args[1]  # avoids property checks in assertEquals
            self.assertEqual(len(expanded_kwargs), 1)
            assert expanded_kwargs['bias'] is args[2]  # avoids property checks in assertEquals