Example #1
0
    def test_unpack_expanded_weight_or_tensor_failure(self, device):
        """A plain tensor that still requires grad is rejected with a RuntimeError.

        The rejection is checked both with the default callback and with a
        custom ``func`` argument.
        """
        mixture_msg = r"does not support a mixture of ExpandedWeight parameters and normal Parameters"
        grad_weight = torch.randn(3, requires_grad=True, device=device)

        # Default-callback path.
        with self.assertRaisesRegex(RuntimeError, mixture_msg):
            unpack_expanded_weight_or_tensor(grad_weight)

        # Custom-callback path: the same error must be raised.
        with self.assertRaisesRegex(RuntimeError, mixture_msg):
            unpack_expanded_weight_or_tensor(grad_weight, lambda t: t is grad_weight)
Example #2
0
    def test_unpack_expanded_weight_or_tensor(self, device):
        """Check unpack_expanded_weight_or_tensor with its default callback.

        Covers three inputs: an ExpandedWeight wrapping a grad-requiring tensor,
        a plain tensor that does not require grad, and a non-tensor value.
        """
        weight = torch.randn(3, requires_grad=True, device=device)
        # An ExpandedWeight unpacks to a value equal to the tensor it wraps.
        self.assertEqual(weight, unpack_expanded_weight_or_tensor(ExpandedWeight(weight, 3)))

        # With requires_grad turned off, a plain tensor is accepted and the
        # result compares equal to it.
        weight.requires_grad_(False)
        self.assertEqual(weight, unpack_expanded_weight_or_tensor(weight))

        # A non-tensor input yields None. assertIsNone reports the actual
        # value on failure, unlike assertTrue(... is None).
        self.assertIsNone(unpack_expanded_weight_or_tensor(4))
Example #3
0
    def test_unpack_expanded_weight_or_tensor_with_custom_function(self, device):
        """Check that unpack_expanded_weight_or_tensor applies a caller-supplied ``func``.

        The callback returns True only when it receives the exact ``weight``
        object. These checks therefore verify the identity of what is handed to
        ``func``, not just its values.
        """
        weight = torch.randn(3, requires_grad=True, device=device)

        def is_weight(x):
            # Identity check. `weight` is mutated in place below, so this
            # closure keeps referring to the same object throughout.
            return x is weight

        # The ExpandedWeight is unwrapped before the callback runs, so the
        # callback sees the original tensor object.
        self.assertTrue(unpack_expanded_weight_or_tensor(ExpandedWeight(weight, 3), is_weight))

        # A plain tensor that does not require grad is handed to the callback directly.
        weight.requires_grad_(False)
        self.assertTrue(unpack_expanded_weight_or_tensor(weight, is_weight))

        # For a non-tensor input the helper returns None, not the callback's
        # boolean result.
        self.assertIsNone(unpack_expanded_weight_or_tensor(4, is_weight))