Example #1
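These snippets exercise PyTorch's internal ExpandedWeights machinery and are methods of a larger test class, so they are not runnable on their own. As a rough sketch, the imports they lean on look like this (module paths follow torch.nn.utils._expanded_weights and may differ across PyTorch versions):

    from itertools import product

    import torch
    import torch.nn as nn
    from torch.testing import make_tensor
    from torch.nn.utils._expanded_weights import ExpandedWeight
    from torch.nn.utils._expanded_weights.expanded_weights_utils import (
        forward_helper,
        set_grad_sample_if_exists,
        standard_kwargs,
        unpack_expanded_weight_or_tensor,
    )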
    def test_forward_helper_failure_args(self, device):
        weight = torch.randn(5, 4, device=device)
        bias = torch.randn(5, device=device)
        with self.assertRaisesRegex(RuntimeError, r"do not support inputs that are also ExpandedWeights."):
            input = ExpandedWeight(torch.randn(3, 4, requires_grad=True), 3)
            expanded_args, expanded_kwargs = standard_kwargs(('bias',), (input, weight, bias))
            forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
        with self.assertRaisesRegex(RuntimeError, r"requires a Tensor as the first input"):
            expanded_args, expanded_kwargs = standard_kwargs(('bias',), (3, weight, bias))
            forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
        with self.assertRaisesRegex(RuntimeError, r"requires a batch dimension but got an input of size 0"):
            expanded_args, expanded_kwargs = standard_kwargs(('bias',), (torch.tensor(3), weight, bias))
            forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
        with self.assertRaisesRegex(RuntimeError, r"0 is not a valid batch size for Expanded Weights"):
            expanded_args, expanded_kwargs = standard_kwargs(('bias',), (torch.randn(0, 1, 2), weight, bias))
            forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
        input = torch.randn(3, 4)
        for (weight_batched, bias_batched) in product([True, False], [True, False]):
            if not weight_batched and not bias_batched:
                continue
            maybe_batched_weight = ExpandedWeight(weight.clone().requires_grad_(), 4) if weight_batched else weight
            maybe_batched_bias = ExpandedWeight(bias.clone().requires_grad_(), 4) if bias_batched else bias
            with self.assertRaisesRegex(RuntimeError, r"Expected ExpandedWeights to have batch size matching input"):
                expanded_args, expanded_kwargs = standard_kwargs(('bias',), (input, maybe_batched_weight, maybe_batched_bias))
                forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
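The error messages above pin down the validation that forward_helper applies to its expanded args. A hypothetical reconstruction inferred purely from the regexes in this test, assuming the imports sketched above (check_expanded_args_sketch is an invented name, and the batch_size attribute is an assumption, not the real implementation):

    def check_expanded_args_sketch(input, args):
        # the first positional arg must be a plain Tensor with a non-empty batch dim
        if isinstance(input, ExpandedWeight):
            raise RuntimeError("Expanded Weights do not support inputs that are also ExpandedWeights.")
        if not isinstance(input, torch.Tensor):
            raise RuntimeError("Expanded Weights requires a Tensor as the first input")
        if input.dim() == 0:
            raise RuntimeError("Expanded Weights requires a batch dimension but got an input of size 0")
        if input.shape[0] == 0:
            raise RuntimeError("0 is not a valid batch size for Expanded Weights")
        # every ExpandedWeight passed alongside must agree with the input's batch size
        for arg in args:
            if isinstance(arg, ExpandedWeight) and arg.batch_size != input.shape[0]:
                raise RuntimeError("Expected ExpandedWeights to have batch size matching input")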
Example #2
    def test_forward_helper(self, device):
        input = torch.randn(3, 4, device=device)
        weight = torch.randn(5, 4, device=device)
        bias = torch.randn(5, device=device)
        for (weight_batched, bias_batched) in product([True, False],
                                                      [True, False]):
            maybe_batched_weight = weight
            maybe_batched_bias = bias
            if weight_batched:
                maybe_batched_weight = ExpandedWeight(
                    weight.clone().requires_grad_(), 3, loss_reduction="sum")
            if bias_batched:
                maybe_batched_bias = ExpandedWeight(
                    bias.clone().requires_grad_(), 3, loss_reduction="sum")
            args = (input, maybe_batched_weight, maybe_batched_bias)
            expanded_args, expanded_kwargs = standard_kwargs(('bias',), args)
            res = forward_helper(nn.functional.linear, expanded_args,
                                 expanded_kwargs)
            expected = nn.functional.linear(input, weight, bias)
            self.assertEqual(res, expected)

            # plain asserts avoid the property checks assertEqual would run
            self.assertEqual(len(expanded_args), 2)
            assert expanded_args[0] is args[0]
            assert expanded_args[1] is args[1]
            self.assertEqual(len(expanded_kwargs), 1)
            assert expanded_kwargs['bias'] is args[2]
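The identity asserts above reveal what standard_kwargs does: the trailing positional args named in kwarg_names are peeled off into a kwargs dict, and everything before them stays positional. A minimal sketch inferred from those asserts (standard_kwargs_sketch is a hypothetical name, not the real helper):

    def standard_kwargs_sketch(kwarg_names, args):
        # the last len(kwarg_names) positional args become keyword args
        split = len(args) - len(kwarg_names)
        return args[:split], dict(zip(kwarg_names, args[split:]))

With kwarg_names=('bias',) and args=(input, weight, bias), this yields ((input, weight), {'bias': bias}), exactly what the asserts check.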
Example #3
    def test_unpack_expanded_weight_or_tensor(self, device):
        input = torch.randn(3, requires_grad=True, device=device)
        self.assertEqual(input, unpack_expanded_weight_or_tensor(ExpandedWeight(input, 3)))

        input.requires_grad_(False)
        self.assertEqual(input, unpack_expanded_weight_or_tensor(input))
        self.assertTrue(unpack_expanded_weight_or_tensor(4) is None)
Example #4
    def test_unpack_expanded_weight_or_tensor_with_custom_function(self, device):
        input = torch.randn(3, requires_grad=True, device=device)
        self.assertTrue(unpack_expanded_weight_or_tensor(ExpandedWeight(input, 3), lambda x: x is input))

        input.requires_grad_(False)
        self.assertTrue(unpack_expanded_weight_or_tensor(input, lambda x: x is input))
        self.assertTrue(unpack_expanded_weight_or_tensor(4, lambda x: x is input) is None)
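Taken together, Examples #3 and #4 outline the unpacking contract: an ExpandedWeight unpacks to its original tensor, a non-differentiable plain tensor unpacks to itself, anything else unpacks to None, and an optional function is applied to whatever was unpacked. A hypothetical sketch consistent with those tests (the orig_weight attribute name is an assumption, and the real helper may treat a differentiable plain tensor, which these tests never pass, as an error):

    def unpack_sketch(maybe_expanded_weight, func=lambda x: x):
        if isinstance(maybe_expanded_weight, ExpandedWeight):
            return func(maybe_expanded_weight.orig_weight)
        if isinstance(maybe_expanded_weight, torch.Tensor) and not maybe_expanded_weight.requires_grad:
            return func(maybe_expanded_weight)
        return None  # non-tensors carry nothing to unpack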
Example #5
    def test_expanded_weight_error(self, device):
        batch_size = 3
        sample_input = make_tensor((batch_size, 4), device, torch.float32,
                                   requires_grad=True)
        sample_weight = make_tensor((4,), device, torch.float32,
                                    requires_grad=True)
        with self.assertRaisesRegex(
                RuntimeError,
                r"Expanded Weights encountered but cannot handle function"):
            torch.add(sample_input, ExpandedWeight(sample_weight, batch_size))
Example #6
    def test_set_grad_sample_if_exists(self, device):
        def test_fn(_):
            return True

        orig_weight = torch.randn(4, device=device, requires_grad=True)
        expanded_weight = ExpandedWeight(orig_weight, 3)
        set_grad_sample_if_exists(expanded_weight, test_fn)
        self.assertTrue(hasattr(orig_weight, 'grad_sample'))
        self.assertTrue(orig_weight.grad_sample)

        basic_tensor = torch.randn(4, device=device)
        set_grad_sample_if_exists(basic_tensor, test_fn)
        self.assertFalse(hasattr(basic_tensor, 'grad_sample'))

        non_tensor = 3
        set_grad_sample_if_exists(non_tensor, test_fn)
        self.assertFalse(hasattr(non_tensor, 'grad_sample'))
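This test pins down set_grad_sample_if_exists: only an ExpandedWeight argument gets the callback's result written to grad_sample on its original tensor; plain tensors and non-tensors are left untouched. A hypothetical reconstruction consistent with the test (orig_weight is again an assumed attribute name):

    def set_grad_sample_if_exists_sketch(maybe_expanded_weight, fn):
        if isinstance(maybe_expanded_weight, ExpandedWeight):
            orig_weight = maybe_expanded_weight.orig_weight
            orig_weight.grad_sample = fn(orig_weight)
        # plain tensors and non-tensor args fall through unchanged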
Example #7
    # batch_size, is_diff_tensor and clone_if_tensor come from the enclosing test scope
    def expanded_weight_or_clone(arg):
        return ExpandedWeight(torch.clone(arg), batch_size) if is_diff_tensor(arg) else clone_if_tensor(arg)
Example #8
    # same helper as Example #7, but threading through a loss_reduction setting
    def expanded_weight_or_clone(arg):
        if is_diff_tensor(arg):
            return ExpandedWeight(torch.clone(arg), batch_size, loss_reduction)
        return clone_if_tensor(arg)
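Per-sample gradients are the point of all this machinery. Wrapping parameters as ExpandedWeight and calling a supported op such as nn.functional.linear (unlike the unsupported torch.add in Example #5) should leave a grad_sample attribute on each original tensor after backward. A rough end-to-end sketch under that assumption; the constructor arity (Example #1 omits loss_reduction, Example #2 passes it) and the grad_sample shapes are inferred from the examples above and may differ across PyTorch versions:

    batch_size = 3
    input = torch.randn(batch_size, 4)
    weight = torch.randn(5, 4, requires_grad=True)
    bias = torch.randn(5, requires_grad=True)

    out = nn.functional.linear(
        input,
        ExpandedWeight(weight, batch_size, loss_reduction="sum"),
        ExpandedWeight(bias, batch_size, loss_reduction="sum"))
    out.sum().backward()

    # one gradient per sample, stacked along a new leading batch dim
    print(weight.grad_sample.shape)  # expected: torch.Size([3, 5, 4])
    print(bias.grad_sample.shape)    # expected: torch.Size([3, 5])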